diff --git a/examples/sup3rwind/running_sup3r_models.ipynb b/examples/sup3rwind/running_sup3r_models.ipynb index b7479f8823..07577d42fa 100644 --- a/examples/sup3rwind/running_sup3r_models.ipynb +++ b/examples/sup3rwind/running_sup3r_models.ipynb @@ -242,7 +242,6 @@ " 'temporal_avg': False,\n", " }\n", "\n", - "\n", "config = {\n", " 'pass_workers': 1,\n", " 'input_handler_kwargs': {\n", diff --git a/pixi.lock b/pixi.lock index 770288f6f1..5edb78d6ed 100644 --- a/pixi.lock +++ b/pixi.lock @@ -10570,8 +10570,8 @@ packages: requires_python: '>=3.9' - pypi: ./ name: nrel-sup3r - version: 0.2.6.dev40+gcbf2eb538.d20260212 - sha256: 34fdf3835de987425a5895c3bde82593b60d70febbc04b705b8e35b103557138 + version: 0.2.6.dev74+g17ecd9ae1.d20260308 + sha256: 2308c5dfecdbc73a9093af87f6197421a11a154908eaaca2e6e9eb3ab1bc4900 requires_dist: - nrel-rex>=0.2.91 - nrel-phygnn>=0.0.32 @@ -10603,7 +10603,6 @@ packages: - pkginfo>=1.10.0,<2 ; extra == 'build' - twine>=5.0 ; extra == 'build' requires_python: '>=3.9' - editable: true - conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-1.26.4-py311h64a7726_0.conda sha256: 3f4365e11b28e244c95ba8579942b0802761ba7bb31c026f50d1a9ea9c728149 md5: a502d7aad449a1206efb366d6a12c52d diff --git a/pyproject.toml b/pyproject.toml index 47ecae58e4..1aabe5c005 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -301,7 +301,7 @@ xarray = ">=2023.0" [tool.pixi.pypi-dependencies] NREL-sup3r = { path = ".", editable = true } -NREL-rex = { version = ">=0.2.87" } +NREL-rex = { version = ">=0.2.91" } NREL-phygnn = { version = ">=0.0.23" } NREL-gaps = { version = ">=0.6.13" } NREL-farms = { version = ">=1.0.4" } diff --git a/sup3r/bias/base.py b/sup3r/bias/base.py index d7761a8767..25664c714f 100644 --- a/sup3r/bias/base.py +++ b/sup3r/bias/base.py @@ -737,7 +737,7 @@ def _reduce_base_data( ) cs_ratio = ( - daily_reduction.lower() in ('avg', 'average', 'mean') + daily_reduction.lower() in {'avg', 'average', 'mean'} and base_dset == 'clearsky_ratio' ) @@ 
-756,16 +756,16 @@ def _reduce_base_data( ) assert not np.isnan(base_data).any(), msg - elif daily_reduction.lower() in ('avg', 'average', 'mean'): + elif daily_reduction.lower() in {'avg', 'average', 'mean'}: base_data = df.groupby('date').mean()['base_data'].values - elif daily_reduction.lower() in ('max', 'maximum'): + elif daily_reduction.lower() in {'max', 'maximum'}: base_data = df.groupby('date').max()['base_data'].values - elif daily_reduction.lower() in ('min', 'minimum'): + elif daily_reduction.lower() in {'min', 'minimum'}: base_data = df.groupby('date').min()['base_data'].values - elif daily_reduction.lower() in ('sum', 'total'): + elif daily_reduction.lower() in {'sum', 'total'}: base_data = df.groupby('date').sum()['base_data'].values msg = ( diff --git a/sup3r/bias/bias_transforms.py b/sup3r/bias/bias_transforms.py index d64418b62c..445440ca74 100644 --- a/sup3r/bias/bias_transforms.py +++ b/sup3r/bias/bias_transforms.py @@ -720,6 +720,8 @@ def local_qdm_bc( -------- sup3r.bias.qdm.QuantileDeltaMappingCorrection : Estimate probability distributions required by QDM method + rex.utilities.bc_utils.QuantileDeltaMapping : + Core QDM transformation. Notes ----- @@ -738,11 +740,6 @@ def local_qdm_bc( Also, :class:`rex.utilities.bc_utils.QuantileDeltaMapping` expects params to be 2D (space, N-params). - See Also - -------- - rex.utilities.bc_utils.QuantileDeltaMapping : - Core QDM transformation. 
- Examples -------- >>> unbiased = local_qdm_bc(biased_array, lat_lon_array, "ghi", "rsds", diff --git a/sup3r/models/__init__.py b/sup3r/models/__init__.py index 0cb5f1bf51..f079fd9310 100644 --- a/sup3r/models/__init__.py +++ b/sup3r/models/__init__.py @@ -7,6 +7,5 @@ from .multi_step import MultiStepGan, MultiStepSurfaceMetGan, SolarMultiStepGan from .solar_cc import SolarCC from .surface import SurfaceSpatialMetModel -from .with_obs import Sup3rGanWithObs SPATIAL_FIRST_MODELS = (MultiStepSurfaceMetGan, SolarMultiStepGan) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index eb819ec9aa..eff2018fb4 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -8,22 +8,23 @@ from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from inspect import signature +from threading import Lock from warnings import warn import numpy as np import pandas as pd import tensorflow as tf from gaps.config import load_config -from phygnn import CustomNetwork from tensorflow.keras import optimizers import sup3r.utilities.loss_metrics +from phygnn import CustomNetwork from sup3r.preprocessing.data_handlers import ExoData from sup3r.preprocessing.utilities import numpy_if_tensor from sup3r.utilities import VERSION_RECORD from sup3r.utilities.utilities import Timer, camel_to_underscore, safe_cast -from .utilities import SUP3R_LAYERS, SUP3R_OBS_LAYERS, TensorboardMixIn +from .utilities import SUP3R_LAYERS, TensorboardMixIn logger = logging.getLogger(__name__) @@ -51,6 +52,7 @@ def __init__(self): self._stdevs = None self._train_record = pd.DataFrame() self._val_record = pd.DataFrame() + self._lock = Lock() def load_network(self, model, name): """Load a CustomNetwork object from hidden layers config, .json file @@ -105,10 +107,9 @@ def _load_model_from_string(self, model, name): return model['meta'][f'config_{name}']['hidden_layers'] msg = ( - 'Could not load model from json config, need ' - '"hidden_layers" key or ' - 
f'"meta/config_{name}/hidden_layers" ' - ' at top level but only found: {}'.format(model.keys()) + 'Could not load model from json config, need "hidden_layers" key ' + f'or "meta/config_{name}/hidden_layers" at top level but only ' + f'found: {model.keys()}' ) logger.error(msg) raise KeyError(msg) @@ -432,8 +433,7 @@ def get_hr_exo_input(self, hi_res): if len(self.hr_exo_features) == 0: return {} inds = [self.hr_features.index(f) for f in self.hr_exo_features] - exo = tf.gather(hi_res, inds, axis=-1) - exo = tf.expand_dims(exo, axis=-2) + exo = tf.expand_dims(tf.gather(hi_res, inds, axis=-1), axis=-2) exo = dict(zip(self.hr_exo_features, tf.unstack(exo, axis=-1))) return exo @@ -460,33 +460,44 @@ def _combine_loss_input(self, hi_res_true, hi_res_gen): hi_res_gen = tf.concat((hi_res_gen, *exo_data), axis=-1) return hi_res_gen - def _get_loss_inputs(self, hi_res_true, hi_res_gen, loss_func): + def _get_loss_inputs(self, hi_res_gen, hi_res_true, loss_func): """Get inputs for the given loss function according to the required - input features""" + generator output features and ground truth features. 
If the loss + function doesn't specify required features, this will default to using + all output features that are not exogenous features.""" - msg = ( - f'{loss_func} requires input features: ' - f'{loss_func.input_features}, but these are not found ' - f'in the model output features: {self.hr_out_features}' - ) - if not all( - f in self.hr_out_features for f in loss_func.input_features - ): + gen_feats = getattr(loss_func, 'gen_features', 'all') + true_feats = getattr(loss_func, 'true_features', 'all') + + if gen_feats == 'all': + gen_feats = self.hr_out_features + if true_feats == 'all': + true_feats = self.hr_out_features + + if not all(f in self.hr_out_features for f in gen_feats): + msg = ( + f'{loss_func} requires gen_features: ' + f'{loss_func.gen_features}, but these are not found in the ' + f'high-resolution output features: {self.hr_out_features}' + ) logger.error(msg) raise ValueError(msg) - input_inds = [ - self.hr_out_features.index(f) for f in loss_func.input_features - ] - hr_true = tf.stack( - [hi_res_true[..., idx] for idx in input_inds], - axis=-1, - ) - hr_gen = tf.stack( - [hi_res_gen[..., idx] for idx in input_inds], - axis=-1, - ) - return hr_true, hr_gen + if not all(f in self.hr_features for f in true_feats): + msg = ( + f'{loss_func} requires true_features: ' + f'{loss_func.true_features}, but these are not found ' + f'in the high-resolution features: {self.hr_features}' + ) + logger.error(msg) + raise ValueError(msg) + + gen_inds = [self.hr_features.index(f) for f in gen_feats] + true_inds = [self.hr_features.index(f) for f in true_feats] + + hr_true = tf.gather(hi_res_true, true_inds, axis=-1) + hr_gen = tf.gather(hi_res_gen, gen_inds, axis=-1) + return hr_gen, hr_true def get_loss_fun(self, loss): """Get full, possibly multi-term, loss function from the provided str @@ -519,21 +530,24 @@ def get_loss_fun(self, loss): loss_funcs = [self._get_loss_fun({ln: loss[ln]}) for ln in lns] weights = copy.deepcopy(loss).pop('term_weights', [1.0] 
* len(lns)) - def loss_fun(hi_res_true, hi_res_gen): + def loss_fun(hi_res_gen, hi_res_true): loss_details = {} loss = 0 for i, (ln, loss_func) in enumerate(zip(lns, loss_funcs)): - if ( - not hasattr(loss_func, 'input_features') - or loss_func.input_features == 'all' - ): - val = loss_func(hi_res_true, hi_res_gen) - else: - hr_true, hr_gen = self._get_loss_inputs( - hi_res_true, hi_res_gen, loss_func - ) - val = loss_func(hr_true, hr_gen) + hr_gen, hr_true = self._get_loss_inputs( + hi_res_gen, hi_res_true, loss_func + ) + val = loss_func(hr_gen, hr_true) loss_details[camel_to_underscore(ln)] = val + tf.debugging.assert_all_finite( + val, + message=( + f'NaN or Inf values found for loss term "{ln}" with ' + f'value {val} when running loss function {loss_func} ' + f'generated tensor of shape {hi_res_gen.shape} and ' + f'true tensor of shape {hi_res_true.shape}' + ), + ) loss += weights[i] * val return loss, loss_details @@ -1033,24 +1047,39 @@ def run_exo_layer(self, layer, input_array, exogenous_data, norm_in=True): of shape: (n_obs, spatial_1, spatial_2, n_features) (n_obs, spatial_1, spatial_2, n_temporal, n_features) + exogenous_data : dict | ExoData + Special dictionary (class:`ExoData`) of exogenous feature data with + entries describing whether features should be combined at input, a + mid network layer, or with output. This doesn't have to include the + 'model' key since this data is for a single step model. + norm_in : bool + Flag to normalize low_res input data if the self._means, + self._stdevs attributes are available. The generator should always + received normalized data with mean=0 stdev=1. This also normalizes + exogenous data. 
""" feat_stack = [] extras = [] features = getattr(layer, 'features', [layer.name]) exo_features = getattr(layer, 'exo_features', []) - is_obs_layer = isinstance(layer, SUP3R_OBS_LAYERS) for feat in features + exo_features: - missing_obs = feat in features and feat not in exogenous_data - if is_obs_layer and missing_obs: + missing_feat = feat not in exogenous_data + if missing_feat and '_obs' in feat: msg = ( f'{feat} does not match any features in exogenous_data ' - f'({list(exogenous_data)}). Will run without this ' - 'observation feature.' + f'({list(exogenous_data)}). Will try to run without this ' + 'feature.' ) logger.warning(msg) continue - msg = f'exogenous_data is missing required feature "{feat}"' - assert feat in exogenous_data, msg + if missing_feat: + msg = ( + f'{feat} does not match any features in exogenous_data ' + f'({list(exogenous_data)}). This feature is required for ' + f'layer {layer.name}.' + ) + logger.error(msg) + raise KeyError(msg) exo = exogenous_data.get_combine_type_data(feat, 'layer') exo = self._reshape_norm_exo( input_array, @@ -1088,7 +1117,7 @@ def generate( Flag to normalize low_res input data if the self._means, self._stdevs attributes are available. The generator should always received normalized data with mean=0 stdev=1. This also normalizes - hi_res_topo. + exogenous data. un_norm_out : bool Flag to un-normalize synthetically generated output data to physical units @@ -1203,8 +1232,9 @@ def _tf_generate(self, low_res, hi_res_exo=None): else: hi_res = layer(hi_res) except Exception as e: - msg = 'Could not run layer #{} "{}" on tensor of shape {}'.format( - layer_num, layer, hi_res.shape + msg = ( + f'Could not run layer #{layer_num} "{layer}" on tensor ' + f'of shape {hi_res.shape}' ) logger.error(msg) raise RuntimeError(msg) from e @@ -1218,7 +1248,8 @@ def _get_hr_exo_and_loss( **calc_loss_kwargs, ): """Get high-resolution exogenous data, generate synthetic output, and - compute loss.""" + compute loss. 
All hr_exo_features are extracted from hi_res_true and + added to exo_data.""" hi_res_exo = self.get_hr_exo_input(hi_res_true) hi_res_gen = self._tf_generate(low_res, hi_res_exo) loss, loss_details = self.calc_loss( @@ -1226,6 +1257,27 @@ def _get_hr_exo_and_loss( ) return loss, loss_details, hi_res_gen, hi_res_exo + @tf.function + def _tf_get_single_grad( + self, + low_res, + hi_res_true, + training_weights, + **calc_loss_kwargs, + ): + """Compiled per-batch gradient step used by :meth:`get_single_grad`. + + Keeping this method tensor-only allows graph compilation while + :meth:`get_single_grad` continues to handle locks and device placement. + """ + with tf.GradientTape(watch_accessed_variables=False) as tape: + tape.watch(training_weights) + loss, loss_details, _, _ = self._get_hr_exo_and_loss( + low_res, hi_res_true, **calc_loss_kwargs + ) + grad = tape.gradient(loss, training_weights) + return grad, loss_details + def get_single_grad( self, low_res, @@ -1266,15 +1318,14 @@ def get_single_grad( loss_details : dict Namespace of the breakdown of loss components """ - with ( - tf.device(device_name), - tf.GradientTape(watch_accessed_variables=False) as tape, - ): - tape.watch(training_weights) - loss, loss_details, _, _ = self._get_hr_exo_and_loss( - low_res, hi_res_true, **calc_loss_kwargs - ) - grad = tape.gradient(loss, training_weights) + with self._lock: + with tf.device(device_name): + grad, loss_details = self._tf_get_single_grad( + low_res, + hi_res_true, + training_weights, + **calc_loss_kwargs, + ) return grad, loss_details @abstractmethod diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 0fa852d288..fa0409f55b 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -38,6 +38,7 @@ def __init__( stdevs=None, default_device=None, name=None, + sparse_disc=False, ): """ Parameters @@ -97,6 +98,15 @@ def __init__( (this was tested as most efficient given the custom multi-gpu strategy developed in self.run_gradient_descent()). 
Examples: "/gpu:0" or "/cpu:0" + sparse_disc : bool + Flag to indicate if the discriminator can handle sparse input data. + If False, the discriminator will expect input data with no missing + values. If True, the discriminator will be able to handle input + data with missing values, which may be the case when using + observations for training. Note that if True, the discriminator + model architecture should be designed to handle sparse data (e.g. + by using masking layers or other techniques). + name : str | None Optional name for the GAN. """ @@ -130,6 +140,7 @@ def __init__( self._means = means self._stdevs = stdevs + self._sparse_disc = sparse_disc def save(self, out_dir): """Save the GAN with its sub-networks to a directory. @@ -298,7 +309,12 @@ def _tf_discriminate(self, hi_res): out : np.ndarray Discriminator output logits """ - out = self.discriminator.layers[0](hi_res) + hr = ( + hi_res + if self._sparse_disc + else tf.gather(hi_res, indices=self.hr_out_features_ind, axis=-1) + ) + out = self.discriminator.layers[0](hr) layer_num = 1 try: for i, layer in enumerate(self.discriminator.layers[1:]): @@ -392,7 +408,7 @@ def weights(self): """ return self.generator_weights + self.discriminator_weights - def init_weights(self, lr_shape, hr_shape, device=None): + def init_weights(self, lr_shape, hr_shape, train_disc=False, device=None): """Initialize the generator and discriminator weights with device placement. @@ -406,12 +422,17 @@ def init_weights(self, lr_shape, hr_shape, device=None): Shape of one batch of high res input data for sup3r resolution. Note that the batch size (axis=0) must be included, but the actual batch size doesnt really matter. + train_disc : bool + Whether to initialize the discriminator weights. If False, only the + generator weights will be initialized. device : str | None Option to place model weights on a device. If None, self.default_device will be used. 
""" - if not self.generator_weights: + no_disc_weights = train_disc and not self.discriminator_weights + no_gen_weights = not self.generator_weights + if no_disc_weights or no_gen_weights: if device is None: device = self.default_device @@ -421,12 +442,12 @@ def init_weights(self, lr_shape, hr_shape, device=None): low_res = tf.cast(np.ones(lr_shape), dtype=tf.float32) hi_res = tf.cast(np.ones(hr_shape), dtype=tf.float32) - hr_exo_shape = hr_shape[:-1] + (1,) + hr_exo_shape = (*hr_shape[:-1], 1) hr_exo = tf.cast(np.ones(hr_exo_shape), dtype=tf.float32) with tf.device(device): hr_exo_data = {} - for feature in self.hr_exo_features + self.obs_features: + for feature in self.hr_exo_features: hr_exo_data[feature] = hr_exo out = self._tf_generate(low_res, hr_exo_data) msg = ( @@ -435,7 +456,9 @@ def init_weights(self, lr_shape, hr_shape, device=None): f'{len(self.hr_out_features)}' ) assert out.shape[-1] == len(self.hr_out_features), msg - _ = self._tf_discriminate(hi_res) + + if train_disc: + _ = self._tf_discriminate(hi_res) @staticmethod def get_weight_update_fraction( @@ -494,14 +517,7 @@ def calc_loss_gen_content(self, hi_res_true, hi_res_gen): 0D tensor generator model loss for the content loss comparing the hi res ground truth to the hi res synthetically generated output. """ - slc = ( - slice(0, None) - if len(self.hr_exo_features) == 0 - else slice(0, -len(self.hr_exo_features)) - ) - # gen is first since loss can included regularizers which just - # apply to generator output - return self.loss_fun(hi_res_gen[..., slc], hi_res_true[..., slc]) + return self.loss_fun(hi_res_gen, hi_res_true) @staticmethod @tf.function @@ -517,12 +533,6 @@ def calc_loss_disc(disc_out_true, disc_out_gen): encourages the generator to produce output which is "more realistic" than the true high-res data. - References - ---------- - .. [Wang2018] Wang, Xintao, et al. "Esrgan: Enhanced super-resolution - generative adversarial networks." 
Proceedings of the European - conference on computer vision (ECCV) workshops. 2018. - Parameters ---------- disc_out_true : tf.Tensor @@ -537,6 +547,12 @@ def calc_loss_disc(disc_out_true, disc_out_gen): loss_disc : tf.Tensor 0D tensor discriminator model loss for either the spatial or temporal component of the super resolution generated output. + + References + ---------- + .. [Wang2018] Wang, Xintao, et al. "Esrgan: Enhanced super-resolution + generative adversarial networks." Proceedings of the European + conference on computer vision (ECCV) workshops. 2018. """ true_logits = disc_out_true - tf.reduce_mean(disc_out_gen) fake_logits = disc_out_gen - tf.reduce_mean(disc_out_true) @@ -606,22 +622,6 @@ def update_adversarial_weights( return weight_gen_advers - @staticmethod - def check_batch_handler_attrs(batch_handler): - """Not all batch handlers have the following attributes. So we perform - some sanitation before sending to `set_model_params`""" - return { - k: getattr(batch_handler, k, None) - for k in [ - 'smoothing', - 'lr_features', - 'hr_exo_features', - 'hr_out_features', - 'smoothed_features', - ] - if hasattr(batch_handler, k) - } - def train( self, batch_handler, @@ -725,12 +725,8 @@ def train( self._init_tensorboard_writer(out_dir) self.set_norm_stats(batch_handler.means, batch_handler.stds) - params = self.check_batch_handler_attrs(batch_handler) self.set_model_params( - input_resolution=input_resolution, - s_enhance=batch_handler.s_enhance, - t_enhance=batch_handler.t_enhance, - **params, + input_resolution=input_resolution, batch_handler=batch_handler ) epochs = list(range(n_epoch)) @@ -881,28 +877,36 @@ def calc_loss( logger.error(msg) raise RuntimeError(msg) - disc_out_true = self._tf_discriminate(hi_res_true) - disc_out_gen = self._tf_discriminate(hi_res_gen) - loss_details = {} loss = None + disc_out_true = None + disc_out_gen = None + loss_gen_advers = None - if compute_disc or train_disc: + if train_disc or compute_disc: + disc_out_true = 
self._tf_discriminate(hi_res_true) + disc_out_gen = self._tf_discriminate(hi_res_gen) loss_details['loss_disc'] = self.calc_loss_disc( disc_out_true=disc_out_true, disc_out_gen=disc_out_gen ) + if train_gen and compute_disc: + loss_gen_advers = self.calc_loss_disc( + disc_out_true=disc_out_gen, disc_out_gen=disc_out_true + ) + loss_details['loss_gen_advers'] = loss_gen_advers + if train_gen: loss_gen_content, loss_gen_content_details = ( self.calc_loss_gen_content(hi_res_true, hi_res_gen) ) - loss_gen_advers = self.calc_loss_disc( - disc_out_true=disc_out_gen, disc_out_gen=disc_out_true + loss = ( + loss_gen_content + if loss_gen_advers is None + else loss_gen_content + weight_gen_advers * loss_gen_advers ) - loss = loss_gen_content + weight_gen_advers * loss_gen_advers loss_details['loss_gen'] = loss loss_details['loss_gen_content'] = loss_gen_content - loss_details['loss_gen_advers'] = loss_gen_advers loss_details.update(loss_gen_content_details) elif train_disc: @@ -971,7 +975,7 @@ def _train_batch( Flag whether to train the discriminator for this set of epochs only_disc : bool Flag whether to only train the discriminator for this set of epochs - gen_too_good : bool + disc_too_good : bool Flag whether to skip training the discriminator and only train the generator, due to superior performance, for this batch. 
weight_gen_advers : float @@ -1140,11 +1144,7 @@ def _train_epoch( Namespace of the breakdown of loss components """ lr_shape, hr_shape = batch_handler.shapes - self.init_weights(lr_shape, hr_shape) - - self.init_weights( - (1, *batch_handler.lr_shape), (1, *batch_handler.hr_shape) - ) + self.init_weights(lr_shape, hr_shape, train_disc=train_disc) disc_th_low = np.min(disc_loss_bounds) disc_th_high = np.max(disc_loss_bounds) diff --git a/sup3r/models/conditional.py b/sup3r/models/conditional.py index a37cd8ff9c..c6d1f4e7ff 100644 --- a/sup3r/models/conditional.py +++ b/sup3r/models/conditional.py @@ -423,16 +423,11 @@ def train( self._init_tensorboard_writer(out_dir) self.set_norm_stats(batch_handler.means, batch_handler.stds) - self.set_model_params( - input_resolution=input_resolution, - s_enhance=batch_handler.s_enhance, - t_enhance=batch_handler.t_enhance, - smoothing=batch_handler.smoothing, - lr_features=batch_handler.lr_features, - hr_exo_features=batch_handler.hr_exo_features, - hr_out_features=batch_handler.hr_out_features, - smoothed_features=batch_handler.smoothed_features, - ) + lower_models = getattr(batch_handler, 'lower_models', {}) + for model in [self, *lower_models.values()]: + model.set_model_params( + input_resolution=input_resolution, batch_handler=batch_handler + ) epochs = list(range(n_epoch)) @@ -444,7 +439,7 @@ def train( t0 = time.time() logger.info( - 'Training model ' 'for {} epochs starting at epoch {}'.format( + 'Training model for {} epochs starting at epoch {}'.format( n_epoch, epochs[0] ) ) diff --git a/sup3r/models/interface.py b/sup3r/models/interface.py index 1387296ecb..f1acdd2356 100644 --- a/sup3r/models/interface.py +++ b/sup3r/models/interface.py @@ -9,13 +9,13 @@ from warnings import warn import numpy as np -from phygnn import CustomNetwork +from phygnn import CustomNetwork from sup3r.preprocessing.data_handlers import ExoData from sup3r.utilities import VERSION_RECORD from sup3r.utilities.utilities import safe_cast -from 
.utilities import SUP3R_EXO_LAYERS, SUP3R_OBS_LAYERS +from .utilities import SUP3R_LAYERS logger = logging.getLogger(__name__) @@ -79,7 +79,6 @@ def input_dims(self): ------- int """ - # pylint: disable=E1101 if hasattr(self, '_gen'): return self._gen.layers[0].rank if hasattr(self, 'models'): @@ -96,43 +95,20 @@ def is_4d(self): """Check if model expects spatial only input""" return self.input_dims == 4 - # pylint: disable=E1101 - def get_s_enhance_from_layers(self): - """Compute factor by which model will enhance spatial resolution from - layer attributes. Used in model training during high res coarsening""" - s_enhance = None - if hasattr(self, '_gen'): - s_enhancements = [ - getattr(layer, '_spatial_mult', 1) - for layer in self._gen.layers - ] - s_enhance = int(np.prod(s_enhancements)) - return s_enhance - - # pylint: disable=E1101 - def get_t_enhance_from_layers(self): - """Compute factor by which model will enhance temporal resolution from - layer attributes. Used in model training during high res coarsening""" - t_enhance = None - if hasattr(self, '_gen'): - t_enhancements = [ - getattr(layer, '_temporal_mult', 1) - for layer in self._gen.layers - ] - t_enhance = int(np.prod(t_enhancements)) - return t_enhance - @property def s_enhance(self): """Factor by which model will enhance spatial resolution. 
Used in model training during high res coarsening and also in forward pass routine to determine shape of needed exogenous data""" + + # If there are multiple steps, we want to return the product of the + # enhancement factors models = getattr(self, 'models', [self]) s_enhances = [m.meta.get('s_enhance', None) for m in models] s_enhance = ( - self.get_s_enhance_from_layers() - if any(s is None for s in s_enhances) - else int(np.prod(s_enhances)) + None + if any(enh is None for enh in s_enhances) + else np.prod(s_enhances) ) if len(models) == 1 and isinstance(self.meta, dict): self.meta['s_enhance'] = s_enhance @@ -143,12 +119,15 @@ def t_enhance(self): """Factor by which model will enhance temporal resolution. Used in model training during high res coarsening and also in forward pass routine to determine shape of needed exogenous data""" + + # If there are multiple steps, we want to return the product of the + # enhancement factors models = getattr(self, 'models', [self]) t_enhances = [m.meta.get('t_enhance', None) for m in models] t_enhance = ( - self.get_t_enhance_from_layers() - if any(t is None for t in t_enhances) - else int(np.prod(t_enhances)) + None + if any(enh is None for enh in t_enhances) + else np.prod(t_enhances) ) if len(models) == 1 and isinstance(self.meta, dict): self.meta['t_enhance'] = t_enhance @@ -218,26 +197,29 @@ def _ensure_valid_input_resolution(self): logger.error(msg) raise RuntimeError(msg) - def _ensure_valid_enhancement_factors(self): - """Ensure user provided enhancement factors are the same as those - computed from layer attributes""" - t_enhance = self.meta.get('t_enhance', None) - s_enhance = self.meta.get('s_enhance', None) - if s_enhance is None or t_enhance is None: - return - - layer_se = self.get_s_enhance_from_layers() - layer_te = self.get_t_enhance_from_layers() - layer_se = layer_se if layer_se is not None else self.meta['s_enhance'] - layer_te = layer_te if layer_te is not None else self.meta['t_enhance'] - msg = ( - 
f'Enhancement factors computed from layer attributes ' - f'(s_enhance={layer_se}, t_enhance={layer_te}) ' - f'conflict with user provided values (s_enhance={s_enhance}, ' - f't_enhance={t_enhance})' - ) - check = layer_se == s_enhance or layer_te == t_enhance - if not check: + def _ensure_feature_consistency(self): + """Ensure that the exogenous features specified in meta are consistent + with the features found in the model layers""" + features = [] + if hasattr(self, '_gen'): + for layer in self._gen.layers: + if isinstance(layer, SUP3R_LAYERS): + feats = ( + [layer.name] + if not hasattr(layer, 'features') + else layer.features + ) + exo_feats = ( + [] + if not hasattr(layer, 'exo_features') + else layer.exo_features + ) + features.extend(feats + exo_feats) + if set(self.hr_exo_features) != set(features): + msg = ( + f'Specified hr_exo_features {self.hr_exo_features} does not ' + f'match features {features} found in model layers.' + ) logger.error(msg) raise RuntimeError(msg) @@ -374,40 +356,24 @@ def hr_out_features(self): generative model outputs.""" return self.meta.get('hr_out_features', []) + @property + def hr_out_features_ind(self): + """Get the indices of the high-resolution output features in the order + they are output by the model.""" + return [self.hr_features.index(feat) for feat in self.hr_out_features] + @property def obs_features(self): """Get list of exogenous observation feature names the model uses. 
These come from the names of the ``Sup3rObs..`` layers.""" - # pylint: disable=E1101 - features = [] - if hasattr(self, '_gen'): - for layer in self._gen.layers: - if isinstance(layer, SUP3R_OBS_LAYERS): - obs_feats = getattr(layer, 'features', [layer.name]) - obs_feats = [f for f in obs_feats if f not in features] - features.extend(obs_feats) - return features + default = [f for f in self.hr_features if '_obs' in f] + return self.meta.get('obs_features', default) @property def hr_exo_features(self): - """Get list of high-resolution exogenous filter names the model uses. - If the model has N concat or add layers this list will be the last N - features in the training features list. The ordering is assumed to be - the same as the order of concat or add layers. If training features is - [..., topo, sza], and the model has 2 concat or add layers, exo - features will be [topo, sza]. Topo will then be used in the first - concat layer and sza will be used in the second""" - # pylint: disable=E1101 - features = [] - if hasattr(self, '_gen'): - features = [ - layer.name - for layer in self._gen.layers - if isinstance(layer, SUP3R_EXO_LAYERS) - ] - obs_feats = [feat.replace('_obs', '') for feat in self.obs_features] - features += [f for f in obs_feats if f not in self.hr_out_features] - return features + """Get list of gapless exogenous high-resolution feature names the + model uses, like topography.""" + return self.meta.get('hr_exo_features', []) @property def hr_features(self): @@ -415,7 +381,11 @@ def hr_features(self): the high-resolution data during training. This includes both output and exogenous features. 
""" - return self.hr_out_features + self.hr_exo_features + out = [ + f for f in self.hr_out_features if f not in self.hr_exo_features + ] + out += self.hr_exo_features + return out @property def smoothing(self): @@ -450,53 +420,57 @@ def version_record(self): """ return VERSION_RECORD - def set_model_params(self, **kwargs): + def set_model_params(self, batch_handler=None, **kwargs): """Set parameters used for training the model Parameters ---------- + batch_handler : object | None + Object that contains attributes used to set model meta parameters. + This is used during training to set parameters that are needed for + generation and also to check for consistency with previously set + parameters when loading a model from disk. kwargs : dict - Keyword arguments including 'input_resolution', - 'lr_features', 'hr_exo_features', 'hr_out_features', - 'smoothed_features', 's_enhance', 't_enhance', 'smoothing' + Keyword arguments that are used to set model meta parameters. This + is used during training to set parameters that are needed for + generation and also to check for consistency with previously set + parameters when loading a model from disk. 
""" - keys = ( - 'input_resolution', + bh_keys = [ + 'smoothing', 'lr_features', 'hr_exo_features', 'hr_out_features', + 'obs_features', 'smoothed_features', 's_enhance', 't_enhance', - 'smoothing', - ) - keys = [k for k in keys if k in kwargs] - if 'hr_out_features' in kwargs: - self.meta['hr_out_features'] = kwargs['hr_out_features'] - - hr_exo_feat = kwargs.get('hr_exo_features', []) - msg = ( - f'Expected high-res exo features {self.hr_exo_features} ' - f'based on model architecture but received "hr_exo_features" ' - f'from data handler: {hr_exo_feat}' - ) - assert list(self.hr_exo_features) == list(hr_exo_feat), msg - - for var in keys: - val = self.meta.get(var, None) + ] + + meta_params = {} + if batch_handler is not None: + meta_params.update({ + k: getattr(batch_handler, k, None) + for k in bh_keys + if hasattr(batch_handler, k) + }) + + meta_params.update(kwargs) + for key, var in meta_params.items(): + val = self.meta.get(key, None) if val is None: - self.meta[var] = kwargs[var] - elif val != kwargs[var]: + self.meta[key] = var + elif val != var: msg = ( - 'Model was previously trained with {var}={} but ' - 'received new {var}={}'.format(val, kwargs[var], var=var) + 'Model was previously trained with {key}={} but ' + 'received new {key}={}'.format(val, var, key=key) ) logger.warning(msg) warn(msg) - self._ensure_valid_enhancement_factors() self._ensure_valid_input_resolution() + self._ensure_feature_consistency() def save_params(self, out_dir): """ diff --git a/sup3r/models/solar_cc.py b/sup3r/models/solar_cc.py index db36fd86bb..ebcb9a808f 100644 --- a/sup3r/models/solar_cc.py +++ b/sup3r/models/solar_cc.py @@ -63,7 +63,7 @@ def __init__(self, *args, t_enhance=None, **kwargs): self._t_enhance = t_enhance or self.t_enhance self.meta['t_enhance'] = self._t_enhance - def init_weights(self, lr_shape, hr_shape, device=None): + def init_weights(self, lr_shape, hr_shape, train_disc=True, device=None): """Initialize the generator and discriminator weights with 
device placement. @@ -77,6 +77,9 @@ def init_weights(self, lr_shape, hr_shape, device=None): Shape of one batch of high res input data for sup3r resolution. Note that the batch size (axis=0) must be included, but the actual batch size doesn't really matter. + train_disc : bool + Whether to initialize the discriminator weights. If False, only the + generator weights will be initialized. device : str | None Option to place model weights on a device. If None, self.default_device will be used. @@ -87,7 +90,9 @@ def init_weights(self, lr_shape, hr_shape, device=None): if hr_shape[3] != self.DAYLIGHT_HOURS: hr_shape = hr_shape[0:3] + (self.DAYLIGHT_HOURS,) + hr_shape[-1:] - super().init_weights(lr_shape, hr_shape, device=device) + super().init_weights( + lr_shape, hr_shape, train_disc=train_disc, device=device + ) @tf.function def calc_loss( @@ -97,7 +102,7 @@ def calc_loss( weight_gen_advers=0.001, train_gen=True, train_disc=False, - compute_disc=False + compute_disc=False, ): """Calculate the GAN loss function using generated and true high resolution data. 
diff --git a/sup3r/models/utilities.py b/sup3r/models/utilities.py index f506de2834..05a1da807f 100644 --- a/sup3r/models/utilities.py +++ b/sup3r/models/utilities.py @@ -11,6 +11,7 @@ Sup3rAdder, Sup3rConcat, Sup3rConcatObs, + Sup3rCrossAttention, Sup3rObsModel, ) from scipy.interpolate import RegularGridInterpolator @@ -20,11 +21,13 @@ logger = logging.getLogger(__name__) -SUP3R_OBS_LAYERS = Sup3rObsModel, Sup3rConcatObs - -SUP3R_EXO_LAYERS = Sup3rAdder, Sup3rConcat - -SUP3R_LAYERS = (*SUP3R_EXO_LAYERS, *SUP3R_OBS_LAYERS) +SUP3R_LAYERS = ( + Sup3rObsModel, + Sup3rConcatObs, + Sup3rAdder, + Sup3rConcat, + Sup3rCrossAttention, +) class TrainingSession: diff --git a/sup3r/models/with_obs.py b/sup3r/models/with_obs.py deleted file mode 100644 index 25ea107e94..0000000000 --- a/sup3r/models/with_obs.py +++ /dev/null @@ -1,299 +0,0 @@ -"""Sup3r model with training on observation data.""" - -import logging - -import numpy as np -import tensorflow as tf - -from sup3r.utilities.utilities import RANDOM_GENERATOR - -from .base import Sup3rGan - -logger = logging.getLogger(__name__) - - -class Sup3rGanWithObs(Sup3rGan): - """Sup3r GAN model which includes mid network observation fusion. This - model is useful for when production runs will be over a domain for which - observation data is available. - - Note - ---- - During training this model uses sparse sampling of ground truth data to - simulate observation data. This is done by creating masks of ground truth - data and then selecting unmasked data. The features used for sampling are - defined in the model configuration, in observation specific model layers - such as ``Sup3rObsModel`` or ``Sup3rConcatObs``. The observation features - in the model configuration should be named with an `_obs` suffix, and - during training these are matched with the corresponding true high res - features without the `_obs` suffix. e.g. 
If the goal is to condition a - model on sparse temperature_10m data then the observation feature should be - named temperature_10m_obs, and during training the gridded high res - temperature_10m data will be masked to create synthetic observation data - for this feature. The observation masks are only used during training. - During inference "real" observation data, with the full name defined in the - model configuration, is passed in as exogenous data with NaN values for - where the observations are not available. These NaN values are then handled - by observation specific model layers. - """ - - def __init__( - self, - *args, - onshore_obs_frac=None, - offshore_obs_frac=None, - loss_obs_weight=0.0, - loss_obs=None, - **kwargs, - ): - """ - Initialize the Sup3rGanWithObs model. - - Parameters - ---------- - args : list - Positional args for ``Sup3rGan`` parent class. - onshore_obs_frac : dict[List] | dict[float] - Fraction of the batch that should be treated as onshore - observations. Should include ``spatial`` key and optionally - ``time`` key if this is a spatiotemporal model. The values should - correspond roughly to the fraction of the production domain for - which onshore observations are available (spatial) and the fraction - of the full time period that these cover. The values can be either - a list (for a lower and upper bound, respectively) or a single - float. For each batch a spatial frac will be selected by either - sampling uniformly between this lower and upper bound or just - using a single float. - offshore_obs_frac : dict - Same as ``onshore_obs_frac`` but for offshore observations. - Offshore observations are frequently sparser than onshore - observations. - loss_obs : str | dict - Loss function to use for the additional observation loss term. This - defaults to the content loss function specified with ``loss``. - loss_obs_weight : float - Value used to weight observation locations in extra content loss - term. e.g. 
The new content loss will include ``obs_loss_weight * - MAE(hi_res_gen[~obs_mask], hi_res_true[~obs_mask])`` - kwargs : dict - Keyword arguments for the ``Sup3rGan`` parent class. - """ - super().__init__(*args, **kwargs) - self.onshore_obs_frac = ( - {} if onshore_obs_frac is None else onshore_obs_frac - ) - self.offshore_obs_frac = ( - {} if offshore_obs_frac is None else offshore_obs_frac - ) - loss_obs = self.loss_name if loss_obs is None else loss_obs - self.loss_obs_name = loss_obs - self.loss_obs_fun = self.get_loss_fun(loss_obs) - self.loss_obs_weight = loss_obs_weight - - @tf.function - def _get_loss_obs_comparison(self, hi_res_true, hi_res_gen, obs_mask): - """Get loss for observation locations and for non observation - locations.""" - hr_true = hi_res_true[..., : len(self.hr_out_features)] - loss_obs, _ = self.loss_obs_fun( - hi_res_gen[~obs_mask], hr_true[~obs_mask] - ) - loss_non_obs, _ = self.loss_obs_fun( - hi_res_gen[obs_mask], hr_true[obs_mask] - ) - return loss_obs, loss_non_obs - - @property - def obs_training_inds(self): - """Get the indices of the observation features in the true high res - data. Obs features have an _obs suffix to avoid name conflict with - fully gridded features. During training these are matched with the - true high res data.""" - hr_feats = [f.replace('_obs', '') for f in self.hr_features] - obs_inds = [ - hr_feats.index(f.replace('_obs', '')) for f in self.obs_features - ] - return obs_inds - - def _get_single_obs_mask(self, hi_res, spatial_frac, time_frac=1.0): - """Get observation mask for a given spatial and temporal obs - fraction for a single batch entry. - - Parameters - ---------- - hi_res : np.ndarray - True high resolution data for a single batch entry. - spatial_frac : float - Fraction of the spatial domain that should be treated as - observations. This is a value between 0 and 1. - time_frac : float, optional - Fraction of the temporal domain that should be treated as - observations. 
This is a value between 0 and 1. Default is 1.0 - - Returns - ------- - np.ndarray - Mask which is True for locations that are not observed and False - for locations that are observed. - (spatial_1, spatial_2, n_features) - (spatial_1, spatial_2, n_temporal, n_features) - """ - mask_shape = [*hi_res.shape[:3], 1, len(self.hr_out_features)] - mask_shape[3] = hi_res.shape[3] if self.is_5d else 1 - s_mask = RANDOM_GENERATOR.uniform(size=mask_shape[1:3]) <= spatial_frac - s_mask = s_mask[..., None, None] - t_mask = RANDOM_GENERATOR.uniform(size=mask_shape[-2]) <= time_frac - t_mask = t_mask[None, None, ..., None] - mask = ~(s_mask & t_mask) - mask = np.repeat(mask, mask_shape[-1], axis=-1) - return mask if self.is_5d else np.squeeze(mask, axis=-2) - - def _get_obs_mask(self, hi_res, spatial_frac, time_frac=1.0): - """Get observation mask for a given spatial and temporal obs - fraction for an entire batch. This is divided between spatial and - temporal fractions because often the spatial fraction is significantly - lower than the temporal fraction in practice, e.g. for a given spatial - location there might be observations for most of the time period but - only a small fraction of the spatial domain is observed. - - Parameters - ---------- - hi_res : np.ndarray - True high resolution data for the entire batch. - spatial_frac : float | list - Fraction of the spatial domain that should be treated as - observations. This is a value between 0 and 1 or a list with - lower and upper bounds for the spatial fraction. - time_frac : float | list, optional - Fraction of the temporal domain that should be treated as - observations. This is a value between 0 and 1 or a list with - lower and upper bounds for the temporal fraction. Default is 1.0 - - Returns - ------- - np.ndarray - Mask which is True for locations that are not observed and False - for locations that are observed. 
- (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - """ - s_range = ( - spatial_frac - if isinstance(spatial_frac, (list, tuple)) - else [spatial_frac, spatial_frac] - ) - t_range = ( - time_frac - if isinstance(time_frac, (list, tuple)) - else [time_frac, time_frac] - ) - s_fracs = RANDOM_GENERATOR.uniform(*s_range, size=hi_res.shape[0]) - t_fracs = RANDOM_GENERATOR.uniform(*t_range, size=hi_res.shape[0]) - s_fracs = np.clip(s_fracs, 0, 1) - t_fracs = np.clip(t_fracs, 0, 1) - mask = tf.stack( - [ - self._get_single_obs_mask(hi_res, s, t) - for s, t in zip(s_fracs, t_fracs) - ], - axis=0, - ) - return mask - - def _get_full_obs_mask(self, hi_res): - """Define observation mask for the current batch. This differs from - ``_get_obs_mask`` by defining a composite mask based on separate - onshore and offshore masks. This is because there is often more - observation data available onshore than offshore.""" - on_sf = self.onshore_obs_frac['spatial'] - on_tf = self.onshore_obs_frac.get('time', 1.0) - obs_mask = self._get_obs_mask(hi_res, on_sf, on_tf) - if 'topography' in self.hr_features and self.offshore_obs_frac: - topo_idx = self.hr_features.index('topography') - topo = hi_res[..., topo_idx] - off_sf = self.offshore_obs_frac['spatial'] - off_tf = self.offshore_obs_frac.get('time', 1.0) - offshore_mask = self._get_obs_mask(hi_res, off_sf, off_tf) - obs_mask = tf.where(topo[..., None] > 0, obs_mask, offshore_mask) - return obs_mask - - @property - def model_params(self): - """ - Model parameters, used to save model to disc - - Returns - ------- - dict - """ - params = super().model_params - params['onshore_obs_frac'] = self.onshore_obs_frac - params['offshore_obs_frac'] = self.offshore_obs_frac - params['loss_obs_weight'] = self.loss_obs_weight - params['loss_obs'] = self.loss_obs_name - return params - - @tf.function - def get_hr_exo_input(self, hi_res_true): - """Mask high res data to act as sparse observation data. 
Add this to - the standard high res exo input""" - exo_data = super().get_hr_exo_input(hi_res_true) - if len(self.obs_features) == 0: - return exo_data - obs_mask = self._get_full_obs_mask(hi_res_true) - nan_const = tf.constant(float('nan'), dtype=hi_res_true.dtype) - obs = tf.gather(hi_res_true, self.obs_training_inds, axis=-1) - obs = tf.where(obs_mask[..., : obs.shape[-1]], nan_const, obs) - obs = tf.expand_dims(obs, axis=-2) - exo_obs = dict(zip(self.obs_features, tf.unstack(obs, axis=-1))) - exo_data.update(exo_obs) - exo_data['mask'] = obs_mask - return exo_data - - def _get_hr_exo_and_loss( - self, - low_res, - hi_res_true, - **calc_loss_kwargs, - ): - """Get high-resolution exogenous data, generate synthetic output, and - compute loss. Includes artificially masking hi res data to act as - sparse observation data.""" - out = super()._get_hr_exo_and_loss( - low_res, hi_res_true, **calc_loss_kwargs - ) - loss, loss_details, hi_res_gen, hi_res_exo = out - - if calc_loss_kwargs.get('train_gen', True): - loss_obs, loss_non_obs = self._get_loss_obs_comparison( - hi_res_true, - hi_res_gen, - hi_res_exo['mask'], - ) - n_obs = tf.reduce_sum(tf.cast(~hi_res_exo['mask'], tf.float32)) - n_total = tf.cast(tf.size(hi_res_exo['mask']), tf.float32) - obs_frac = n_obs / n_total - loss_update = { - 'loss_obs': loss_obs, - 'loss_non_obs': loss_non_obs, - 'obs_frac': obs_frac, - } - if self.loss_obs_weight and obs_frac > 0: - loss_obs *= self.loss_obs_weight - loss += loss_obs - loss_details['loss_gen'] += loss_obs - loss_details['loss_gen_content'] += loss_obs - loss_details.update(loss_update) - return loss, loss_details, hi_res_gen, hi_res_exo - - def _post_batch(self, ib, b_loss_details, n_batches, previous_means): - """Update loss details after the current batch and write to log.""" - if 'obs_frac' in b_loss_details: - logger.debug( - 'Batch {} out of {} has obs_frac: {:.4e}'.format( - ib + 1, n_batches, b_loss_details['obs_frac'] - ) - ) - return super()._post_batch( - 
ib, b_loss_details, n_batches, previous_means - ) diff --git a/sup3r/pipeline/forward_pass_cli.py b/sup3r/pipeline/forward_pass_cli.py index eaf380dd1c..9700271748 100644 --- a/sup3r/pipeline/forward_pass_cli.py +++ b/sup3r/pipeline/forward_pass_cli.py @@ -60,7 +60,8 @@ def from_config(ctx, config_file, verbose=False, pipeline_step=None): sig = signature(ForwardPassStrategy) strategy_kwargs = {k: v for k, v in config.items() if k in sig.parameters} - strategy = ForwardPassStrategy(**strategy_kwargs, head_node=True) + head_node = strategy_kwargs.get('max_nodes') != 1 + strategy = ForwardPassStrategy(**strategy_kwargs, head_node=head_node) if node_index is not None: nodes = ( diff --git a/sup3r/preprocessing/__init__.py b/sup3r/preprocessing/__init__.py index 5a8dcc6246..ab4e201d7f 100644 --- a/sup3r/preprocessing/__init__.py +++ b/sup3r/preprocessing/__init__.py @@ -51,11 +51,11 @@ from .names import COORD_NAMES, DIM_NAMES, FEATURE_NAMES, Dimension from .rasterizers import ( BaseExoRasterizer, + DerivedFeatureRasterizer, DualRasterizer, ExoRasterizer, ObsRasterizer, Rasterizer, - SzaRasterizer, ) from .samplers import ( DualSampler, diff --git a/sup3r/preprocessing/batch_handlers/factory.py b/sup3r/preprocessing/batch_handlers/factory.py index c22e19992d..d19cc8e337 100644 --- a/sup3r/preprocessing/batch_handlers/factory.py +++ b/sup3r/preprocessing/batch_handlers/factory.py @@ -151,21 +151,31 @@ def __init__( memory right away. feature_sets : Optional[dict] Optional dictionary describing how the full set of features is - split between `lr_only_features` and `hr_exo_features`. + split between ``lr_features``, ``hr_features``, and + ``hr_out_features``. - features : list | tuple - List of full set of features to use for sampling. If no - entry is provided then all data_vars from container data + lr_features : list | tuple + List of feature names or patt*erns to use as low-resolution + model inputs. 
If no entry is provided then all available + features from the data will be used. + hr_out_features : list | tuple + List of feature names or patt*erns that should be output by + the generative model and available as ground truth targets. + If no entry is provided then all features in lr_features will be used. - lr_only_features : list | tuple - List of feature names or patt*erns that should only be - included in the low-res training set and not the high-res - observations. This hr_exo_features : list | tuple - List of feature names or patt*erns that should be included - in the high-resolution observation but not expected to be - output from the generative model. An example is high-res - topography that is to be injected mid-network. + List of feature names or patt*erns that should be available + as high-resolution model inputs (like topography or + observations). These are injected into the model + mid-network to condition output on high-resolution + information. The model configuration should have the + appropriate layers to use these features. e.g. + ``Sup3rConcat`` for topography injection, ``Sup3rObsModel`` + or ``Sup3rCrossAttention`` for obs injection. If no entry + is provided then hr_exo_features will be empty. + + *To include sparse features as inputs or targets the features + must have an "_obs" suffix. kwargs : dict Additional keyword arguments for BatchQueue and / or Samplers. 
This can vary depending on the type of BatchQueue / Sampler diff --git a/sup3r/preprocessing/batch_queues/base.py b/sup3r/preprocessing/batch_queues/base.py index ccba5deb09..c45f8c81cf 100644 --- a/sup3r/preprocessing/batch_queues/base.py +++ b/sup3r/preprocessing/batch_queues/base.py @@ -13,21 +13,18 @@ class SingleBatchQueue(AbstractBatchQueue): - """Base BatchQueue class for single dataset containers - - Note - ---- - Here we use `len(self.features)` for the last dimension of samples, since - samples in :class:`SingleBatchQueue` queues are coarsened to produce - low-res samples, and then the `lr_only_features` are removed with - `hr_features_ind`. In contrast, for samples in :class:`DualBatchQueue` - queues there are low / high res pairs and the high-res only stores the - `hr_features`""" + """Base BatchQueue class for single dataset containers""" @property def queue_shape(self): """Shape of objects stored in the queue.""" - return [(self.batch_size, *self.hr_sample_shape, len(self.features))] + return [ + ( + self.batch_size, + *self.hr_sample_shape, + len(self.hr_source_features), + ) + ] def transform( self, @@ -69,7 +66,8 @@ def transform( (batch_size, spatial_1, spatial_2, features) (batch_size, spatial_1, spatial_2, temporal, features) """ - low_res = spatial_coarsening(samples, self.s_enhance) + lr_samples = numpy_if_tensor(samples)[..., self.lr_features_ind] + low_res = spatial_coarsening(lr_samples, self.s_enhance) low_res = ( low_res if self.t_enhance == 1 @@ -81,7 +79,7 @@ def transform( smoothing_ignore if smoothing_ignore is not None else [] ) low_res = smooth_data( - low_res, self.features, smoothing_ignore, smoothing + low_res, self.lr_features, smoothing_ignore, smoothing ) high_res = numpy_if_tensor(samples)[..., self.hr_features_ind] return low_res, high_res diff --git a/sup3r/preprocessing/batch_queues/dual.py b/sup3r/preprocessing/batch_queues/dual.py index 56b2b08d42..865e6ff745 100644 --- a/sup3r/preprocessing/batch_queues/dual.py +++ 
b/sup3r/preprocessing/batch_queues/dual.py @@ -20,7 +20,6 @@ def __init__(self, samplers, **kwargs): -------- :class:`~sup3r.preprocessing.batch_queues.abstract.AbstractBatchQueue` """ - self.BATCH_MEMBERS = samplers[0].dset_names super().__init__(samplers, **kwargs) self.check_enhancement_factors() @@ -31,16 +30,10 @@ def queue_shape(self): """Shape of objects stored in the queue. Optionally includes shape of observation data which would be included in an extra content loss term""" - obs_shape = ( - *self.hr_shape[:-1], - len(self.containers[0].hr_out_features), - ) - queue_shapes = [ + return [ (self.batch_size, *self.lr_shape), (self.batch_size, *self.hr_shape), - (self.batch_size, *obs_shape), ] - return queue_shapes[: len(self.BATCH_MEMBERS)] def check_enhancement_factors(self): """Make sure each DualSampler has the same enhancment factors and they diff --git a/sup3r/preprocessing/collections/stats.py b/sup3r/preprocessing/collections/stats.py index 665994547c..12b0233304 100644 --- a/sup3r/preprocessing/collections/stats.py +++ b/sup3r/preprocessing/collections/stats.py @@ -62,7 +62,16 @@ def _get_stat(self, stat_type, needed_features='all'): getattr(c.high_res[hr_feats], stat_type)(skipna=True) for c in self.containers ] - if any(lr_feats): + lr_check = any(lr_feats) + container_check = all(hasattr(c, 'low_res') for c in self.containers) + if lr_check and not container_check: + msg = ( + f'Found low-res features {lr_feats} but not all containers ' + 'have low-res data. ' + ) + logger.error(msg) + raise ValueError(msg) + elif lr_check and container_check: cstats_lr = [ getattr(c.low_res[lr_feats], stat_type)(skipna=True) for c in self.containers @@ -92,9 +101,9 @@ def _init_stats_dict(self, stats): and any(f not in stats for f in self.features) ): msg = ( - f'Not all features ({self.features}) are found in the given ' - f'stats dictionary {stats}. This is obviously from a prior ' - 'run so you better be sure these stats carry over.' 
+ f'Not all features ({self.features}) are found in the ' + f'given stats dictionary {stats}. This is obviously from a ' + 'prior run so you better be sure these stats carry over.' ) logger.warning(msg) warn(msg) @@ -116,7 +125,7 @@ def get_means(self, means): ] for f in needed_features: logger.info(f'Computing mean for {f}.') - means[f] = np.float32(np.sum([cm[f] for cm in cmeans])) + means[f] = np.float32(np.nansum([cm[f] for cm in cmeans])) return means def get_stds(self, stds): @@ -132,7 +141,9 @@ def get_stds(self, stds): ] for f in needed_features: logger.info(f'Computing std for {f}.') - stds[f] = np.float32(np.sqrt(np.sum([cs[f] for cs in cstds]))) + stds[f] = np.float32( + np.sqrt(np.nansum([cs[f] for cs in cstds])) + ) return stds @staticmethod diff --git a/sup3r/preprocessing/data_handlers/exo.py b/sup3r/preprocessing/data_handlers/exo.py index 4d337663fb..e1aed21fc6 100644 --- a/sup3r/preprocessing/data_handlers/exo.py +++ b/sup3r/preprocessing/data_handlers/exo.py @@ -302,11 +302,10 @@ def __init__( feature : str Exogenous feature to extract from file_paths file_paths : str | list - A single source h5 file or netcdf file to extract raster - data from. The string can be a unix-style file path which - will be passed through glob.glob. This is typically low-res - WRF output or GCM netcdf data that is source low-resolution - data intended to be sup3r resolved. + Filepaths(s) used to define the grid that the high-resolution + exogenous data will be mapped onto. This can be either a single h5 + file or a list of netcdf files with identical grid. The string can + be a unix-style file path which will be passed through glob.glob. model : Sup3rGan | MultiStepGan Model used to get exogenous data. If a ``MultiStepGan`` ``lr_features``, ``hr_exo_features``, and @@ -331,7 +330,9 @@ def __init__( exo_rasterizer_kwargs : dict Keyword arguments passed to the :class:`~sup3r.preprocessing.rasterizers.exo.BaseExoRasterizer` - class. + class. 
This is used to specify parameters for exogenous data + rasterization such as the ``source_files`` for the exogenous data + and the method of rasterization. """ self.feature = feature self.file_paths = file_paths @@ -465,7 +466,7 @@ def _get_single_step_enhance(self, step): 'Received exo_kwargs entry without valid combine_type ' '(input/layer/output)' ) - assert combine_type.lower() in ('input', 'output', 'layer'), msg + assert combine_type.lower() in {'input', 'output', 'layer'}, msg if combine_type.lower() == 'input': if mstep == 0: s_enhance = 1 diff --git a/sup3r/preprocessing/derivers/methods.py b/sup3r/preprocessing/derivers/methods.py index 030bc70269..16ecd696fb 100644 --- a/sup3r/preprocessing/derivers/methods.py +++ b/sup3r/preprocessing/derivers/methods.py @@ -134,7 +134,7 @@ class CloudMask(DerivedFeature): """Cloud Mask feature class. Inputs here are typically found in H5 data like the NSRDB.""" - inputs = ('ghi', 'clearky_ghi') + inputs = ('ghi', 'clearsky_ghi') @classmethod def compute(cls, data): @@ -407,7 +407,8 @@ def compute(cls, data): """Compute method for latitude.""" lat = data[Dimension.LATITUDE] lat = lat.expand_dims(Dimension.TIME, axis=-1) - lat = np.repeat(lat, len(data.time_index), axis=-1) + n_time = 1 if data.time_index is None else len(data.time_index) + lat = np.repeat(lat, n_time, axis=-1) return lat.astype(np.float32) @@ -419,10 +420,27 @@ def compute(cls, data): """Compute method for longitude.""" lon = data[Dimension.LONGITUDE] lon = lon.expand_dims(Dimension.TIME, axis=-1) - lon = np.repeat(lon, len(data.time_index), axis=-1) + n_time = 1 if data.time_index is None else len(data.time_index) + lon = np.repeat(lon, n_time, axis=-1) return lon.astype(np.float32) +class Time(DerivedFeature): + """Time feature with latitude and longitude dimensions included.""" + + @classmethod + def compute(cls, data): + """Compute method for time.""" + time = data[Dimension.TIME].astype('datetime64[s]').astype(np.int64) + # Expand along the 2D 
spatial dimensions, then explicitly repeat along + # each dimension using its size to handle non-square grids correctly. + spatial_dims = Dimension.dims_2d() + time = time.expand_dims(spatial_dims, axis=(0, 1)) + time = np.repeat(time, data.sizes[spatial_dims[0]], axis=0) + time = np.repeat(time, data.sizes[spatial_dims[1]], axis=1) + return time.astype(np.float64) + + class SpatioTemporalEncoding(DerivedFeature): """General positional or temporal encoding. @@ -533,6 +551,7 @@ def compute(cls, data, i=1): 'sza': Sza, 'latitude_feature': Latitude, 'longitude_feature': Longitude, + 'time_feature': Time, 'soy_encoding': SecondOfYearEncoding, 'sod_encoding': SecondOfDayEncoding, 'lat_encoding': LatitudeEncoding, diff --git a/sup3r/preprocessing/rasterizers/__init__.py b/sup3r/preprocessing/rasterizers/__init__.py index f0b0dc8f34..76bccb1029 100644 --- a/sup3r/preprocessing/rasterizers/__init__.py +++ b/sup3r/preprocessing/rasterizers/__init__.py @@ -8,5 +8,10 @@ from .base import BaseRasterizer from .dual import DualRasterizer -from .exo import BaseExoRasterizer, ExoRasterizer, ObsRasterizer, SzaRasterizer +from .exo import ( + BaseExoRasterizer, + DerivedFeatureRasterizer, + ExoRasterizer, + ObsRasterizer, +) from .extended import Rasterizer diff --git a/sup3r/preprocessing/rasterizers/dual.py b/sup3r/preprocessing/rasterizers/dual.py index 981cfec09f..66e1a9f6f3 100644 --- a/sup3r/preprocessing/rasterizers/dual.py +++ b/sup3r/preprocessing/rasterizers/dual.py @@ -92,11 +92,11 @@ def __init__( f'Sup3rDataset instance. Received {type(data)}.' 
) assert isinstance(data, Sup3rDataset), msg - self.lr_data, self.hr_data = data.low_res, data.high_res + self.data = data self.regrid_workers = regrid_workers - lr_step = self.lr_data.time_step - hr_step = self.hr_data.time_step + lr_step = self.data.low_res.time_step + hr_step = self.data.high_res.time_step msg = ( f'Time steps of high-res data ({hr_step} seconds) and low-res ' f'data ({lr_step} seconds) are inconsistent with t_enhance = ' @@ -105,9 +105,9 @@ def __init__( assert np.allclose(lr_step, hr_step * self.t_enhance), msg self.lr_required_shape = ( - self.hr_data.shape[0] // self.s_enhance, - self.hr_data.shape[1] // self.s_enhance, - self.hr_data.shape[2] // self.t_enhance, + self.data.high_res.shape[0] // self.s_enhance, + self.data.high_res.shape[1] // self.s_enhance, + self.data.high_res.shape[2] // self.t_enhance, ) self.hr_required_shape = ( self.s_enhance * self.lr_required_shape[0], @@ -118,16 +118,16 @@ def __init__( msg = ( f'The required low-res shape {self.lr_required_shape} is ' 'inconsistent with the shape of the raw data ' - f'{self.lr_data.shape}' + f'{self.data.low_res.shape}' ) assert all( req_s <= true_s for req_s, true_s in zip( - self.lr_required_shape, self.lr_data.shape + self.lr_required_shape, self.data.low_res.shape ) ), msg - self.hr_lat_lon = self.hr_data.lat_lon[ + self.hr_lat_lon = self.data.high_res.lat_lon[ slice(self.hr_required_shape[0]), slice(self.hr_required_shape[1]) ] self.lr_lat_lon = spatial_coarsening( @@ -137,49 +137,51 @@ def __init__( self.update_lr_data() self.update_hr_data() - super().__init__(data=(self.lr_data, self.hr_data)) + super().__init__(data=(self.data.low_res, self.data.high_res)) if run_qa: self.check_regridded_lr_data() if lr_cache_kwargs is not None: - Cacher(self.lr_data, lr_cache_kwargs) + Cacher(self.data.low_res, lr_cache_kwargs) if hr_cache_kwargs is not None: - Cacher(self.hr_data, hr_cache_kwargs) + Cacher(self.data.high_res, hr_cache_kwargs) def update_hr_data(self): """Set the high 
resolution data attribute and check if hr_data.shape is divisible by s_enhance. If not, take the largest shape that can be.""" msg = ( - f'hr_data.shape: {self.hr_data.shape[:3]} is not ' + f'hr_data.shape: {self.data.high_res.shape[:3]} is not ' f'divisible by s_enhance: {self.s_enhance}. Using shape: ' f'{self.hr_required_shape} instead.' ) - need_new_shape = self.hr_data.shape[:3] != self.hr_required_shape[:3] + need_new_shape = ( + self.data.high_res.shape[:3] != self.hr_required_shape[:3] + ) if need_new_shape: logger.warning(msg) warn(msg) hr_data_new = {} - for f in self.hr_data.features: + for f in self.data.high_res.features: hr_slices = [slice(sh) for sh in self.hr_required_shape] - hr = self.hr_data.to_dataarray().sel(variable=f).data + hr = self.data.high_res.to_dataarray().sel(variable=f).data hr_data_new[f] = hr[tuple(hr_slices)] hr_coords_new = { Dimension.LATITUDE: self.hr_lat_lon[..., 0], Dimension.LONGITUDE: self.hr_lat_lon[..., 1], - Dimension.TIME: self.hr_data.indexes['time'][ + Dimension.TIME: self.data.high_res.indexes['time'][ : self.hr_required_shape[2] ], } logger.info( - 'Updating self.hr_data with new shape: ' + 'Updating self.data.high_res with new shape: ' f'{self.hr_required_shape[:3]}' ) - self.hr_data = self.hr_data.update_ds({ + self.data.high_res = self.data.high_res.update_ds({ **hr_coords_new, **hr_data_new, }) @@ -191,7 +193,9 @@ def get_regridder(self): data=self.lr_lat_lon.reshape((-1, 2)), ) return Regridder( - self.lr_data.meta, target_meta, max_workers=self.regrid_workers + self.data.low_res.meta, + target_meta, + max_workers=self.regrid_workers, ) def update_lr_data(self): @@ -203,20 +207,20 @@ def update_lr_data(self): regridder = self.get_regridder() lr_data_new = {} - for f in self.lr_data.features: - lr = self.lr_data.to_dataarray().sel(variable=f).data + for f in self.data.low_res.features: + lr = self.data.low_res.to_dataarray().sel(variable=f).data lr = lr[..., : self.lr_required_shape[2]] lr_data_new[f] = 
regridder(lr).reshape(self.lr_required_shape) lr_coords_new = { Dimension.LATITUDE: self.lr_lat_lon[..., 0], Dimension.LONGITUDE: self.lr_lat_lon[..., 1], - Dimension.TIME: self.lr_data.indexes[Dimension.TIME][ + Dimension.TIME: self.data.low_res.indexes[Dimension.TIME][ : self.lr_required_shape[2] ], } - logger.info('Updating self.lr_data with regridded data.') - self.lr_data = self.lr_data.update_ds({ + logger.info('Updating self.data.low_res with regridded data.') + self.data.low_res = self.data.low_res.update_ds({ **lr_coords_new, **lr_data_new, }) @@ -225,8 +229,8 @@ def check_regridded_lr_data(self): """Check for NaNs after regridding and do NN fill if needed.""" fill_feats = [] logger.info('Checking for NaNs after regridding') - qa_info = self.lr_data.qa(stats=['nan_perc']) - for f in self.lr_data.features: + qa_info = self.data.low_res.qa(stats=['nan_perc']) + for f in self.data.low_res.features: nan_perc = qa_info[f]['nan_perc'] if nan_perc > 0: msg = f'{f} data has {nan_perc:.3f}% NaN values!' 
@@ -244,6 +248,6 @@ def check_regridded_lr_data(self): f'features = {fill_feats}' ) logger.info(msg) - self.lr_data = self.lr_data.interpolate_na( + self.data.low_res = self.data.low_res.interpolate_na( features=fill_feats, method='nearest' ) diff --git a/sup3r/preprocessing/rasterizers/exo.py b/sup3r/preprocessing/rasterizers/exo.py index 63fb80e5a4..5f5838425c 100644 --- a/sup3r/preprocessing/rasterizers/exo.py +++ b/sup3r/preprocessing/rasterizers/exo.py @@ -19,7 +19,7 @@ from sup3r.preprocessing.accessor import Sup3rX from sup3r.preprocessing.base import Sup3rMeta -from sup3r.preprocessing.derivers.utilities import SolarZenith +from sup3r.preprocessing.derivers.methods import RegistryBase from sup3r.preprocessing.loaders import Loader from sup3r.preprocessing.names import Dimension from sup3r.preprocessing.utilities import compute_if_dask @@ -35,16 +35,14 @@ class BaseExoRasterizer(ABC): """Class to extract high-res (4km+) data rasters for new spatially-enhanced datasets (e.g. GCM files after spatial enhancement) using nearest neighbor - mapping and aggregation from high-res datasets (e.g. WTK or NSRDB) + mapping and aggregation from high-res datasets (e.g. WTK or NSRDB). Parameters ---------- feature : str Name of exogenous feature to rasterize. file_paths : str | list - Filepaths(s) for typically low-res WRF output or GCM netcdf data files - that is source low-resolution data intended to be sup3r resolved. - These are used to define the grid that the high-resolution exogenous + Filepaths(s) used to define the grid that the high-resolution exogenous data will be mapped onto. This can be either a single h5 file or a list of netcdf files with identical grid. The string can be a unix-style file path which will be passed through glob.glob. 
@@ -128,7 +126,12 @@ class will output a topography raster corresponding to the file_paths # These sometimes have a time dimension but we don't need the time in # the cache file - STATIC_FEATURES: ClassVar = ('topography', 'srl') + STATIC_FEATURES: ClassVar = ( + 'topography', + 'srl', + 'latitude_feature', + 'longitude_feature', + ) @log_args def __post_init__(self): @@ -190,7 +193,7 @@ def cache_file(self): fn += f'{"x".join(map(str, self.input_handler.grid_shape))}_' # add the time index to the filename if data is time dependent - if self.source_data.shape[-1] > 1: + if self.hr_shape[-1] > 1 and self.feature not in self.STATIC_FEATURES: start = str(self.hr_time_index[0]) start = start.replace(':', '').replace('-', '').replace(' ', '') end = str(self.hr_time_index[-1]) @@ -211,8 +214,8 @@ def coords(self): coord: (Dimension.dims_2d(), self.hr_lat_lon[..., i]) for i, coord in enumerate(Dimension.coords_2d()) } - if self.source_data.shape[1] > 1: - coords['time'] = self.hr_time_index + if self.hr_shape[-1] > 1 and self.feature not in self.STATIC_FEATURES: + coords[Dimension.TIME] = self.hr_time_index return coords @property @@ -340,6 +343,7 @@ def data(self): data = Loader(cache_fp) else: data = self.get_data() + logger.info(f'Finished rasterizing "{self.feature}"') if cache_fp is not None and not os.path.exists(cache_fp): Cacher._write_single( @@ -386,10 +390,6 @@ def get_data(self): hr_data *= self.scale_factor hr_data = self._check_coverage(hr_data) - logger.info( - f'Finished mapping raster from {self.source_files} for ' - f'"{self.feature}"', - ) data_vars = (dims, da.asarray(hr_data, dtype=np.float32)) data_vars = {self.feature: data_vars} return Sup3rX(xr.Dataset(coords=self.coords, data_vars=data_vars)) @@ -488,23 +488,36 @@ class ObsRasterizer(BaseExoRasterizer): @property def source_handler(self): - """Get the Loader object that handles the exogenous data file. 
This - assumes the feature name does not have the '_obs' suffix which is used - to trigger this rasterizer.""" - feat = self.feature.replace('_obs', '') + """Get the Loader object that handles the exogenous data file.""" if self._source_handler is None: - self._source_handler = Loader( - self.source_files, - features=[feat], - **self.source_handler_kwargs, - ) + try: + self._source_handler = Loader( + self.source_files, + features=[self.feature], + **self.source_handler_kwargs, + ) + except KeyError: + msg = ( + f'{self.feature} not found in {self.source_files}. ' + f'Will check {self.feature.replace("_obs", "")}.' + ) + logger.warning(msg) + self._source_handler = Loader( + self.source_files, + features=[self.feature.replace('_obs', '')], + **self.source_handler_kwargs, + ) return self._source_handler @property def source_data(self): """Get the flattened observation data from the source_files""" if self._source_data is None: - feat = self.feature.replace('_obs', '') + feat = ( + self.feature + if self.feature in self.source_handler + else self.feature.replace('_obs', '') + ) src = self.source_handler[feat].data self._source_data = src.reshape((-1, src.shape[-1])) return self._source_data @@ -533,23 +546,24 @@ def _check_coverage(self, hr_data): return hr_data -class SzaRasterizer(BaseExoRasterizer): - """SzaRasterizer for H5 files""" +class DerivedFeatureRasterizer(BaseExoRasterizer): + """Rasterizer for features that can be derived from lat/lon and time with + a method in `RegistryBase`. 
For example, features like sza that are + computed from the lat/lon and time of the high-resolution grid.""" @property def source_data(self): - """Get the 1D array of sza data from the source_file_h5""" - return SolarZenith.get_zenith(self.hr_time_index, self.hr_lat_lon) + """Derive the source data from the lat/lon and time of the + high-resolution grid.""" + if self._source_data is None: + ds = Sup3rX(xr.Dataset(coords=self.coords)) + ds[self.feature] = RegistryBase[self.feature.lower()].compute(ds) + self._source_data = ds + return self._source_data def get_data(self): - """Get a raster of source values corresponding to the high-res grid - (the file_paths input grid * s_enhance * t_enhance). The shape is - (lats, lons, temporal) - """ - logger.info(f'Finished computing {self.feature} data') - data_vars = {self.feature: (Dimension.dims_3d(), self.source_data)} - ds = xr.Dataset(coords=self.coords, data_vars=data_vars) - return Sup3rX(ds) + """Pass through for `source_data` to override base class method.""" + return self.source_data class ExoRasterizer(BaseExoRasterizer, metaclass=Sup3rMeta): @@ -557,9 +571,34 @@ class ExoRasterizer(BaseExoRasterizer, metaclass=Sup3rMeta): def __new__(cls, feature, file_paths, source_files=None, **kwargs): """Override parent class to return type specific class based on - `source_files`""" - if feature.lower() == 'sza': - ExoClass = SzaRasterizer + `source_files` + + Parameters + ---------- + feature : str + Name of exogenous feature to rasterize. If the feature name ends + with '_obs' then the `ObsRasterizer` will be used which is designed + for sparse spatiotemporal observation data. If the feature can be + derived from just grid variables then the + `DerivedFeatureRasterizer` will be used. Otherwise, the + `BaseExoRasterizer` will be used which is designed for more dense + spatiotemporal data like topography or srl. 
+ file_paths : str | list + Filepaths(s) used to define the grid that the high-resolution + exogenous data will be mapped onto. This can be either a single h5 + file or a list of netcdf files with identical grid. The string can + be a unix-style file path which will be passed through glob.glob. + source_files : str | list | None + Filepath(s) to hi-res exogenous data, which will be mapped to the + enhanced grid of the file_paths input. + """ + if feature.lower() in { + 'sza', + 'latitude_feature', + 'longitude_feature', + 'time_feature', + }: + ExoClass = DerivedFeatureRasterizer elif feature.lower().endswith('_obs'): ExoClass = ObsRasterizer else: @@ -571,7 +610,6 @@ def __new__(cls, feature, file_paths, source_files=None, **kwargs): 'feature': feature, **kwargs, } - logger.info( f'Using {ExoClass.__name__} to rasterize feature "{feature}"' ) diff --git a/sup3r/preprocessing/samplers/base.py b/sup3r/preprocessing/samplers/base.py index b6be139d71..f7e69f428a 100644 --- a/sup3r/preprocessing/samplers/base.py +++ b/sup3r/preprocessing/samplers/base.py @@ -15,6 +15,7 @@ uniform_time_sampler, ) from sup3r.preprocessing.utilities import compute_if_dask, lowered +from sup3r.utilities.utilities import RANDOM_GENERATOR logger = logging.getLogger(__name__) @@ -29,6 +30,7 @@ def __init__( sample_shape: Optional[tuple] = None, batch_size: int = 16, feature_sets: Optional[dict] = None, + proxy_obs_kwargs: Optional[dict] = None, mode: str = 'lazy', ): """ @@ -52,20 +54,46 @@ def __init__( samples and then stacking. feature_sets : Optional[dict] Optional dictionary describing how the full set of features is - split between ``lr_only_features`` and ``hr_exo_features``. - - features : list | tuple - List of full set of features to use for sampling. If no entry - is provided then all data_vars from data will be used. - lr_only_features : list | tuple - List of feature names or patt*erns that should only be - included in the low-res training set and not the high-res - observations. 
+ split between ``lr_features``, ``hr_exo_features``, and + ``hr_out_features``. + + lr_features : list | tuple + List of feature names or patt*erns to use as low-resolution + model inputs. If no entry is provided then all available + features from the data will be used. + hr_out_features : list | tuple + List of feature names or patt*erns that should be output + by the generative model and available as ground truth targets. + If no entry is provided then all features in lr_features will + be used. hr_exo_features : list | tuple - List of feature names or patt*erns that should be included - in the high-resolution observation but not expected to be - output from the generative model. An example is high-res - topography that is to be injected mid-network. + List of feature names or patt*erns that should be available + as high-resolution model inputs (like topography or + observations). These are injected into the model mid-network + to condition output on high-resolution information. The model + configuration should have the appropriate layers to use these + features. e.g. ``Sup3rConcat`` for topography injection, + ``Sup3rObsModel`` or ``Sup3rCrossAttention`` for obs injection. + If no entry is provided then hr_exo_features will be empty. + + *To include sparse features as inputs or targets the features + must have an "_obs" suffix. + proxy_obs_kwargs : dict | None + Optional dictionary of keyword arguments to pass to the proxy + observation generator. This is only used when training with proxy + observations. Keys can include ``onshore_obs_frac`` and + ``offshore_obs_frac`` which specify the fraction of the batch that + should be treated as onshore and offshore observations, + respectively. 
For example, ``proxy_obs_kwargs={'onshore_obs_frac': + {'spatial': 0.1, 'time': 0.2}, 'offshore_obs_frac': {'spatial': + 0.05, 'time': 0.1}}`` would specify that for the onshore region + observations cover 10% of the spatial domain and 20% of the + temporal domain, while for the offshore region observations cover + 5% of the spatial domain and 10% of the temporal domain. Instead of + a single float, these can also be lists to specify a lower and + upper bound for the spatial and temporal fractions, in which case + the actual fraction for each batch will be sampled uniformly + between these bounds. mode : str Mode for sampling data. Options are 'lazy' or 'eager'. 'eager' mode pre-loads all data into memory as numpy arrays for faster access. @@ -74,18 +102,59 @@ """ super().__init__(data=data) feature_sets = feature_sets or {} - self.features = feature_sets.get('features', self.data.features) - self._lr_only_features = feature_sets.get('lr_only_features', []) + self._lr_features = feature_sets.get('lr_features', self.data.features) + self._hr_exo_features = feature_sets.get('hr_exo_features', []) + self._hr_out_features = feature_sets.get('hr_out_features', []) + self.proxy_obs_kwargs = proxy_obs_kwargs or {} self.mode = mode self.sample_shape = sample_shape or (10, 10, 1) self.batch_size = batch_size - self.lr_features = self.features self.preflight() + self.check_feature_consistency() + + @property + def use_proxy_obs(self): + """Whether to use proxy observations. When True, proxy observation + features are generated by masking the corresponding gridded ground + truth data and are appended to the samples. The obs features are + specified by the ``obs_features`` argument and should have a + corresponding source feature in the data features that is used for + sampling. For example, an obs feature named ``temperature_obs`` would + be generated from the gridded ground truth feature named + ``temperature``.
+ """ + return bool(self.proxy_obs_kwargs) + + @property + def onshore_obs_frac(self): + """Fraction of onshore observations to include in each batch when using + proxy observations. This can be a single float or a dictionary with + keys 'spatial' and 'temporal' to specify the fraction for each domain. + If a dictionary is provided, the actual fraction for each batch will be + sampled uniformly between the specified spatial and temporal fractions. + """ + return self.proxy_obs_kwargs.get('onshore_obs_frac', {}) + + @property + def offshore_obs_frac(self): + """Fraction of offshore observations to include in each batch when + using proxy observations. This can be a single float or a dictionary + with keys 'spatial' and 'temporal' to specify the fraction for each + domain. If a dictionary is provided, the actual fraction for each + batch will be sampled uniformly between the specified spatial and + temporal fractions. + """ + return self.proxy_obs_kwargs.get('offshore_obs_frac', {}) def get_sample_index(self, n_obs=None): """Randomly gets spatiotemporal sample index. + Returns + ------- + sample_index : tuple + Tuple of latitude slice, longitude slice, time slice, and features. + Used to get single observation like ``self.data[sample_index]`` + Notes ----- If ``n_obs > 1`` this will get a time slice with ``n_obs * @@ -93,23 +162,21 @@ def get_sample_index(self, n_obs=None): ``n_obs`` samples each with ``self.sample_shape[2]`` time steps. This is a much more efficient way of getting batches of samples but only works if there are enough continuous time steps to sample. - - Returns - ------- - sample_index : tuple - Tuple of latitude slice, longitude slice, time slice, and features. 
- Used to get single observation like ``self.data[sample_index]`` """ n_obs = n_obs or self.batch_size spatial_slice = uniform_box_sampler(self.shape, self.sample_shape[:2]) time_slice = uniform_time_sampler( self.shape, self.sample_shape[2] * n_obs ) - return (*spatial_slice, time_slice, self.features) + feats = ( + [f for f in self.hr_source_features if f not in self.obs_features] + if self.use_proxy_obs + else self.hr_source_features + ) + return (*spatial_slice, time_slice, feats) def preflight(self): - """Check if the sample_shape is larger than the requested raster - size""" + """Perform shape and feature checks.""" good_shape = ( self.sample_shape[0] <= self.data.shape[0] and self.sample_shape[1] <= self.data.shape[1] @@ -144,6 +211,74 @@ def preflight(self): logger.info('Received mode = "eager".') _ = self.compute() + def check_feature_consistency(self): + """Check that the feature sets are consistent with each other and the + obs features are configured correctly.""" + if self.use_proxy_obs and not all( + f in self.hr_source_features for f in self.obs_features + ): + msg = ( + 'When using proxy observations, all obs features must be ' + 'included either in hr_out_features or hr_exo_features.' + ) + raise ValueError(msg) + + if self.use_proxy_obs and any( + f in self.data.features for f in self.obs_features + ): + msg = ( + f'Obs features {self.obs_features} cannot be in the data ' + f'features {self.data.features} when using proxy observations.' 
+ ) + raise ValueError(msg) + + if len(self.obs_features) > 0 and any( + f in self.hr_exo_features for f in self.obs_features + ): + msg = ( + f'Obs features {self.obs_features} must come at the end of ' + f'the hr_exo_features {self.hr_exo_features}' + ) + assert list(self.obs_features) == list( + self.hr_exo_features[-len(self.obs_features) :] + ), msg + + if len(self.hr_exo_features) > 0: + msg = ( + f'hr_exo_features {self.hr_exo_features} must come at the end ' + f'of the full high-res feature set: {self.hr_source_features}' + ) + assert list(self.hr_exo_features) == list( + self.hr_source_features[-len(self.hr_exo_features) :] + ), msg + + assert all( + f in lowered(self.data.features) for f in self.lr_features + ), ( + f'All lr_features {self.lr_features} must be in the data features ' + f'{self.data.features}.' + ) + assert all( + f in lowered(self.data.features) for f in self.hr_out_features + ), ( + f'All hr_out_features {self.hr_out_features} must be in the data ' + f'features {self.data.features}.' + ) + if not self.use_proxy_obs: + assert all( + f in lowered(self.data.features) for f in self.hr_exo_features + ), ( + f'All hr_exo_features {self.hr_exo_features} must be in the ' + f'data features {self.data.features} when not using proxy ' + 'observations.' + ) + else: + feats = set(self.hr_exo_features) - set(self.obs_features) + assert all(f in lowered(self.data.features) for f in feats), ( + f'All non-obs hr_exo_features {feats} must be in the data ' + f'features {self.data.features} when using proxy observations.' 
+ ) + @property def sample_shape(self) -> tuple: """Shape of the data sample to select when ``__next__()`` is called.""" @@ -256,8 +391,10 @@ def _fast_batch(self): out = self.data.sample(self.get_sample_index(n_obs=self.batch_size)) out = self._compute_samples(out) if isinstance(out, tuple): - return tuple(self._reshape_samples(o) for o in out) - return self._reshape_samples(out) + out = tuple(self._reshape_samples(o) for o in out) + else: + out = self._reshape_samples(out) + return self._append_obs_features(out) def _slow_batch(self): """Get batch of samples with random time slices.""" @@ -266,23 +403,82 @@ def _slow_batch(self): for _ in range(self.batch_size) ] out = self._compute_samples(out) - return self._stack_samples(out) + out = self._stack_samples(out) + return self._append_obs_features(out) def _fast_batch_possible(self): return self.batch_size * self.sample_shape[2] <= self.data.shape[2] + def _get_proxy_obs(self, hi_res): + """Generate proxy observation data by masking the gridded high-res + data. Unobserved locations are set to NaN. + + Parameters + ---------- + hi_res : np.ndarray + High resolution batch data with shape: + (batch_size, spatial_1, spatial_2, temporal, n_features) + + Returns + ------- + obs : np.ndarray + Observation data with NaN for unobserved locations. Shape: + (batch_size, spatial_1, spatial_2, temporal, n_obs_features) + """ + obs_mask = self._get_full_obs_mask(hi_res) + obs = hi_res[..., self.obs_features_ind].copy() + obs[obs_mask[..., : obs.shape[-1]]] = np.nan + return obs + + def _append_obs_features(self, samples): + """Append proxy observation features to the batch samples when + ``use_proxy_obs=True``. The obs features are generated by masking + the corresponding gridded ground truth features. + + Parameters + ---------- + samples : np.ndarray | tuple[np.ndarray, ...] + Batch samples from the data source. For single datasets, shape + is (batch_size, s1, s2, t, n_features). For dual datasets, + this is a tuple of arrays. 
+ + Returns + ------- + samples : np.ndarray | tuple[np.ndarray, ...] + Same as input, but with obs features appended to the last + dimension if proxy obs are enabled. + """ + if not self.use_proxy_obs: + return samples + + if isinstance(samples, tuple): + # For dual datasets, obs features are appended to the high-res + # member (last element) + hr = samples[-1] + obs = self._get_proxy_obs(hr) + hr = np.concatenate([hr, obs], axis=-1) + return (*samples[:-1], hr) + + obs = self._get_proxy_obs(samples) + return np.concatenate([samples, obs], axis=-1) + def __next__(self): """Get next batch of samples. This retrieves n_samples = batch_size with shape = sample_shape from the `.data` (a xr.Dataset or Sup3rDataset) through the Sup3rX accessor. + When ``use_proxy_obs=True`` and ``obs_features`` are configured, proxy + observation features are generated by masking the corresponding + gridded ground truth data and are appended to the samples. + Returns ------- samples : tuple(np.ndarray | da.core.Array) | np.ndarray | da.core.Array Either a tuple or single array of samples. This is a tuple when this method is sampling from a ``Sup3rDataset`` with two data - members - """ # pylint: disable=line-too-long # noqa + members. When proxy obs are enabled, obs features are appended + to the feature dimension. + """ # noqa: E501 if self._fast_batch_possible(): return self._fast_batch() return self._slow_batch() @@ -311,10 +507,49 @@ def _parse_features(self, unparsed_feats): return lowered(parsed_feats) @property - def lr_only_features(self): - """List of feature names or patt*erns that should only be included in - the low-res training set and not the high-res observations.""" - return self._parse_features(self._lr_only_features) + def lr_features(self): + """List of feature names or patt*erns to use as low-resolution model + inputs. 
If no entry is provided then all available features from the + data will be used.""" + return self._parse_features(self._lr_features) + + @property + def hr_source_features(self): + """List of feature names or patt*erns that should be available natively + as high-resolution. For a non-dual sampler this is all features, since + even features only provided to the model as low-resolution still need + to be coarsened from the high-resolution data. This is in contrast to + dual samplers + (:class:`~sup3r.preprocessing.samplers.dual.DualSampler`), where there + are separate high-resolution and low-resolution data members.""" + feats = [f for f in self.lr_features if f not in self.hr_exo_features] + feats += [ + f + for f in self.hr_out_features + if f not in feats and f not in self.hr_exo_features + ] + feats += [f for f in self.hr_exo_features if f not in feats] + return feats + + @property + def hr_features(self): + """List of feature names or patt*erns that the model is shown at + high-resolution. This does not include features that are only shown to + the model after coarsening. Thus, this includes hr_out_features + and hr_exo_features.""" + out = [ + f for f in self.hr_out_features if f not in self.hr_exo_features + ] + out += self.hr_exo_features + return out + + @property + def hr_out_features(self): + """List of feature names or patt*erns that should be output by the + generative model. If no entry is provided then all features in + lr_features will be used.""" + hr_out = self._parse_features(self._hr_out_features) + return self.lr_features if len(hr_out) == 0 else hr_out @property def hr_exo_features(self): @@ -322,63 +557,133 @@ def hr_exo_features(self): for training e.g., mid-network high-res topo injection. These must come at the end of the high-res feature set.
These can also be input to the model as low-res features.""" - self._hr_exo_features = self._parse_features(self._hr_exo_features) - - if len(self._hr_exo_features) > 0: - msg = ( - f'High-res train-only features "{self._hr_exo_features}" ' - f'do not come at the end of the full high-res feature set: ' - f'{self.features}' - ) - last_feat = self.features[-len(self._hr_exo_features) :] - assert list(self._hr_exo_features) == list(last_feat), msg - - return self._hr_exo_features + return self._parse_features(self._hr_exo_features) @property - def hr_out_features(self): - """Get a list of high-resolution features that are intended to be - output by the GAN. Does not include high-resolution exogenous - features""" - - out = [] - for feature in self.features: - lr_only = any( - fnmatch(feature.lower(), pattern.lower()) - for pattern in self.lr_only_features - ) - ignore = lr_only or feature in self.hr_exo_features - if not ignore: - out.append(feature) - - if len(out) == 0: - msg = ( - f'It appears that all handler features "{self.features}" ' - 'were specified as `hr_exo_features` or `lr_only_features` ' - 'and therefore there are no output features!' - ) - logger.error(msg) - raise RuntimeError(msg) - - return lowered(out) + def obs_features(self): + """List of feature names or patt*erns that should be treated as + observations. These features will be included in the high-res data but + not the low-res data and won't necessarily be expected to be output by + the generative model. These are different from other `hr_exo_features` + in that they are intended to be used as observation features with NaN + values where observations are not available.""" + return [f for f in self.hr_source_features if '_obs' in f] @property def hr_features_ind(self): """Get the high-resolution feature channel indices that should be - included for training. 
Any high-resolution features that are only - included in the data handler to be coarsened for the low-res input are - removed""" - hr_features = list(self.hr_out_features) + list(self.hr_exo_features) - if list(self.features) == hr_features: - return np.arange(len(self.features)) - return [ - i - for i, feature in enumerate(self.features) - if feature in hr_features - ] + included for loss calculations. This includes hr_out_features and + hr_exo_features. Any high-resolution features that are only included in + the data handler to be coarsened for the low-res input are removed. + """ + return [self.hr_source_features.index(f) for f in self.hr_features] @property - def hr_features(self): - """Get the high-resolution features corresponding to - `hr_features_ind`""" - return [self.features[ind].lower() for ind in self.hr_features_ind] + def lr_features_ind(self): + """Get the low-resolution feature channel indices that should be + included for training. This includes lr_features. + """ + return [self.hr_source_features.index(f) for f in self.lr_features] + + @property + def obs_features_ind(self): + """Get the source feature indices in ``features`` for each obs + feature. Each obs feature named ``<feature>_obs`` maps to the + corresponding ``<feature>`` in the features. + + Returns + ------- + list[int] + Indices into ``features`` for each obs feature source. + """ + check_feats = ( + [f.replace('_obs', '') for f in self.obs_features] + if self.use_proxy_obs + else self.obs_features + ) + + return [self.hr_source_features.index(f) for f in check_feats] + + def _get_obs_mask(self, hi_res, spatial_frac, time_frac=1.0): + """Get observation mask for a given spatial and temporal obs + fraction for an entire batch. This is divided between spatial and + temporal fractions because often the spatial fraction is significantly + lower than the temporal fraction in practice, e.g.
for a given spatial + location there might be observations for most of the time period but + only a small fraction of the spatial domain is observed. + + Parameters + ---------- + hi_res : np.ndarray + True high resolution data for the entire batch. + spatial_frac : float | list + Fraction of the spatial domain that should be treated as + observations. This is a value between 0 and 1 or a list with + lower and upper bounds for the spatial fraction. + time_frac : float | list, optional + Fraction of the temporal domain that should be treated as + observations. This is a value between 0 and 1 or a list with + lower and upper bounds for the temporal fraction. Default is 1.0 + + Returns + ------- + np.ndarray + Mask which is True for locations that are not observed and False + for locations that are observed. + (n_obs, spatial_1, spatial_2, n_features) + (n_obs, spatial_1, spatial_2, n_temporal, n_features) + + Notes + ----- + The output mask is repeated along the feature dimension, so each + feature will have the same observation mask. The output mask is not + repeated along the batch dimension, so each sample in the batch will + have a different observation mask. 
+ """ + s_range = ( + spatial_frac + if isinstance(spatial_frac, (list, tuple)) + else [spatial_frac, spatial_frac] + ) + t_range = ( + time_frac + if isinstance(time_frac, (list, tuple)) + else [time_frac, time_frac] + ) + n_obs, n_spatial_1, n_spatial_2, n_temporal = hi_res.shape[:-1] + n_features = len(self.obs_features) + + s_fracs = RANDOM_GENERATOR.uniform(*s_range, size=n_obs) + t_fracs = RANDOM_GENERATOR.uniform(*t_range, size=n_obs) + s_fracs = np.clip(s_fracs, 0, 1) + t_fracs = np.clip(t_fracs, 0, 1) + + s_mask = RANDOM_GENERATOR.uniform( + size=(n_obs, n_spatial_1, n_spatial_2) + ) + s_mask = s_mask <= s_fracs[:, None, None] + s_mask = s_mask[..., None, None] + + t_mask = RANDOM_GENERATOR.uniform(size=(n_obs, n_temporal)) + t_mask = t_mask <= t_fracs[:, None] + t_mask = t_mask[:, None, None, :, None] + + mask = ~(s_mask & t_mask) + return np.repeat(mask, n_features, axis=-1) + + def _get_full_obs_mask(self, hi_res): + """Define observation mask for the current batch. This differs from + ``_get_obs_mask`` by defining a composite mask based on separate + onshore and offshore masks. 
This is because there is often more + observation data available onshore than offshore.""" + on_sf = self.onshore_obs_frac.get('spatial', 0.0) + on_tf = self.onshore_obs_frac.get('time', 1.0) + obs_mask = self._get_obs_mask(hi_res, on_sf, on_tf) + if 'topography' in self.hr_source_features and self.offshore_obs_frac: + topo_idx = self.hr_source_features.index('topography') + topo = hi_res[..., topo_idx] + off_sf = self.offshore_obs_frac.get('spatial', 0.0) + off_tf = self.offshore_obs_frac.get('time', 1.0) + offshore_mask = self._get_obs_mask(hi_res, off_sf, off_tf) + obs_mask = np.where(topo[..., None] > 0, obs_mask, offshore_mask) + return obs_mask diff --git a/sup3r/preprocessing/samplers/cc.py b/sup3r/preprocessing/samplers/cc.py index 85af6d88c1..57fead2dbf 100644 --- a/sup3r/preprocessing/samplers/cc.py +++ b/sup3r/preprocessing/samplers/cc.py @@ -54,17 +54,30 @@ def __init__( Temporal enhancement factor feature_sets : Optional[dict] Optional dictionary describing how the full set of features is - split between ``lr_only_features`` and ``hr_exo_features``. - - lr_only_features : list | tuple - List of feature names or patt*erns that should only be - included in the low-res training set and not the high-res - observations. + split between ``lr_features``, ``hr_exo_features``, and + ``hr_out_features``. + + lr_features : list | tuple + List of feature names or patt*erns to use as low-resolution + model inputs. If no entry is provided then all available + features from the data will be used. + hr_out_features : list | tuple + List of feature names or patt*erns that should be output + by the generative model and available as ground truth targets. + If no entry is provided then all features in lr_features will + be used. hr_exo_features : list | tuple - List of feature names or patt*erns that should be included - in the high-resolution observation but not expected to be - output from the generative model. 
An example is high-res - topography that is to be injected mid-network. + List of feature names or patt*erns that should be available + as high-resolution model inputs (like topography or + observations). These are injected into the model mid-network + to condition output on high-resolution information. The model + configuration should have the appropriate layers to use these + features. e.g. ``Sup3rConcat`` for topography injection, + ``Sup3rObsModel`` or ``Sup3rCrossAttention`` for obs injection. + If no entry is provided then hr_exo_features will be empty. + + *To include sparse features as inputs or targets the features + must have an "_obs" suffix. mode : str Mode for sampling data. Options are 'lazy' or 'eager'. 'eager' mode pre-loads all data into memory as numpy arrays for faster access. @@ -100,21 +113,21 @@ def __init__( mode=mode, ) - def check_for_consistent_shapes(self): + def check_shape_consistency(self): """Make sure container shapes and sample shapes are compatible with enhancement factors.""" enhanced_shape = ( - self.lr_data.shape[0] * self.s_enhance, - self.lr_data.shape[1] * self.s_enhance, - self.lr_data.shape[2] * (1 if self.t_enhance == 1 else 24), + self.data.low_res.shape[0] * self.s_enhance, + self.data.low_res.shape[1] * self.s_enhance, + self.data.low_res.shape[2] * (1 if self.t_enhance == 1 else 24), ) msg = ( - f'hr_data.shape {self.hr_data.shape} and enhanced ' + f'hr_data.shape {self.data.high_res.shape} and enhanced ' f'lr_data.shape {enhanced_shape} are not compatible with ' f'the given enhancement factors t_enhance = {self.t_enhance}, ' f's_enhance = {self.s_enhance}' ) - assert self.hr_data.shape[:3] == enhanced_shape, msg + assert self.data.high_res.shape[:3] == enhanced_shape, msg def reduce_high_res_sub_daily(self, high_res, csr_ind=0): """Take an hourly high-res observation and reduce the temporal axis diff --git a/sup3r/preprocessing/samplers/dc.py b/sup3r/preprocessing/samplers/dc.py index a48c80df57..2312ce9ced 100644 --- 
a/sup3r/preprocessing/samplers/dc.py +++ b/sup3r/preprocessing/samplers/dc.py @@ -56,8 +56,30 @@ def __init__( efficient than getting N = batch_size samples and then stacking. feature_sets : Optional[dict] Optional dictionary describing how the full set of features is - split between `lr_only_features` and `hr_exo_features`. See - :class:`~sup3r.preprocessing.Sampler` + split between ``lr_features``, ``hr_exo_features``, and + ``hr_out_features``. + + lr_features : list | tuple + List of feature names or patt*erns to use as low-resolution + model inputs. If no entry is provided then all available + features from the data will be used. + hr_out_features : list | tuple + List of feature names or patt*erns that should be output + by the generative model and available as ground truth targets. + If no entry is provided then all features in lr_features will + be used. + hr_exo_features : list | tuple + List of feature names or patt*erns that should be available + as high-resolution model inputs (like topography or + observations). These are injected into the model mid-network + to condition output on high-resolution information. The model + configuration should have the appropriate layers to use these + features. e.g. ``Sup3rConcat`` for topography injection, + ``Sup3rObsModel`` or ``Sup3rCrossAttention`` for obs injection. + If no entry is provided then hr_exo_features will be empty. + + *To include sparse features as inputs or targets the features + must have an "_obs" suffix. mode : str Loading mode for sampling. 
See :class:`~sup3r.preprocessing.Sampler` diff --git a/sup3r/preprocessing/samplers/dual.py b/sup3r/preprocessing/samplers/dual.py index 8eb8a446c0..a3c26bcfa2 100644 --- a/sup3r/preprocessing/samplers/dual.py +++ b/sup3r/preprocessing/samplers/dual.py @@ -6,7 +6,6 @@ from typing import Optional from sup3r.preprocessing.base import Sup3rDataset -from sup3r.preprocessing.utilities import lowered from .base import Sampler from .utilities import uniform_box_sampler, uniform_time_sampler @@ -29,6 +28,7 @@ def __init__( s_enhance: int = 1, t_enhance: int = 1, feature_sets: Optional[dict] = None, + proxy_obs_kwargs: Optional[dict] = None, mode: str = 'lazy', ): """ @@ -36,7 +36,7 @@ def __init__( ---------- data : Sup3rDataset A :class:`~sup3r.preprocessing.base.Sup3rDataset` instance with - low-res and high-res data members, and optionally an obs member. + low-res and high-res data members. sample_shape : tuple Size of arrays to sample from the high-res data. The sample shape for the low-res sampler will be determined from the enhancement @@ -47,17 +47,46 @@ def __init__( Temporal enhancement factor feature_sets : Optional[dict] Optional dictionary describing how the full set of features is - split between `lr_only_features` and `hr_exo_features`. + split between ``lr_features``, ``hr_exo_features``, and + ``hr_out_features``. - lr_only_features : list | tuple - List of feature names or patt*erns that should only be - included in the low-res training set and not the high-res - observations. + lr_features : list | tuple + List of feature names or patt*erns to use as low-resolution + model inputs. If no entry is provided then all available + features from the data will be used. + hr_out_features : list | tuple + List of feature names or patt*erns that should be output + by the generative model and available as ground truth targets. + If no entry is provided then all features in lr_features will + be used. 
hr_exo_features : list | tuple - List of feature names or patt*erns that should be included - in the high-resolution observation but not expected to be - output from the generative model. An example is high-res - topography that is to be injected mid-network. + List of feature names or patt*erns that should be available + as high-resolution model inputs (like topography or + observations). These are injected into the model mid-network + to condition output on high-resolution information. The model + configuration should have the appropriate layers to use these + features. e.g. ``Sup3rConcat`` for topography injection, + ``Sup3rObsModel`` or ``Sup3rCrossAttention`` for obs injection. + If no entry is provided then hr_exo_features will be empty. + + *To include sparse features as inputs or targets the features + must have an "_obs" suffix. + proxy_obs_kwargs : dict | None + Optional dictionary of keyword arguments to pass to the proxy + observation generator. This is only used when training with proxy + observations. Keys can include ``onshore_obs_frac`` and + ``offshore_obs_frac`` which specify the fraction of the batch that + should be treated as onshore and offshore observations, + respectively. For example, ``proxy_obs_kwargs={ 'onshore_obs_frac': + { 'spatial': 0.1, 'temporal': 0.2}, 'offshore_obs_frac': { + 'spatial': 0.05, 'temporal': 0.1} }`` would specify that for the + onshore region observations cover 10% of the spatial domain and 20% + of the temporal domain, while for the offshore region observations + cover 5% of the spatial domain and 10% of the temporal domain. + Instead of a single float, these can also be lists to specify a + lower and upper bound for the spatial and temporal fractions, in + which case the actual fraction for each batch will be sampled + uniformly between these bounds. mode : str Mode for sampling data. Options are 'lazy' or 'eager'. 'eager' mode pre-loads all data into memory as numpy arrays for faster access. 
@@ -66,38 +95,39 @@ def __init__( """ msg = ( f'{self.__class__.__name__} requires a Sup3rDataset object ' - 'with `.low_res` and `.high_res` data members, and optionally an ' - '`.obs` member, in that order' + 'with `.low_res` and `.high_res` data members, in that order' ) - dnames = ['low_res', 'high_res', 'obs'][: len(data)] - check = ( + check = all( hasattr(data, dname) and getattr(data, dname) == data[i] - for i, dname in enumerate(dnames) + for i, dname in enumerate(['low_res', 'high_res']) ) assert check, msg - super().__init__( - data=data, - sample_shape=sample_shape, - batch_size=batch_size, - mode=mode, + self.data = data + feature_sets = feature_sets or {} + self._lr_features = feature_sets.get( + 'lr_features', self.data.low_res.features ) - self.lr_data, self.hr_data = self.data.low_res, self.data.high_res + self._hr_exo_features = feature_sets.get('hr_exo_features', []) + self._hr_out_features = feature_sets.get( + 'hr_out_features', self.data.high_res.features + ) + self.proxy_obs_kwargs = proxy_obs_kwargs or {} + self.mode = mode + self.sample_shape = sample_shape or (10, 10, 1) + self.batch_size = batch_size + self.lr_sample_shape = ( self.hr_sample_shape[0] // s_enhance, self.hr_sample_shape[1] // s_enhance, self.hr_sample_shape[2] // t_enhance, ) - feature_sets = feature_sets or {} - self._lr_only_features = feature_sets.get('lr_only_features', []) - self._hr_exo_features = feature_sets.get('hr_exo_features', []) - self.features = self.get_features(feature_sets) - self.lr_features = [ - f for f in self.features if f in self.lr_data.features - ] self.s_enhance = s_enhance self.t_enhance = t_enhance - self.check_for_consistent_shapes() + + self.preflight() + self.check_shape_consistency() + self.check_feature_consistency() post_init_args = { 'lr_sample_shape': self.lr_sample_shape, 'hr_sample_shape': self.hr_sample_shape, @@ -106,33 +136,29 @@ def __init__( } self.post_init_log(post_init_args) - def get_features(self, feature_sets): - """Return 
default set of features composed from data vars in low res - and high res data objects or the value provided through the - feature_sets dictionary.""" - features = [] - _ = [ - features.append(f) - for f in [*self.lr_data.features, *self.hr_data.features] - if f not in features and f not in lowered(self._hr_exo_features) + @property + def hr_source_features(self): + """Features available natively at high-resolution.""" + out = [ + f for f in self.hr_out_features if f not in self.hr_exo_features ] - features += lowered(self._hr_exo_features) - return feature_sets.get('features', features) + out += self.hr_exo_features + return out - def check_for_consistent_shapes(self): + def check_shape_consistency(self): """Make sure container shapes are compatible with enhancement factors.""" enhanced_shape = ( - self.lr_data.shape[0] * self.s_enhance, - self.lr_data.shape[1] * self.s_enhance, - self.lr_data.shape[2] * self.t_enhance, + self.data.low_res.shape[0] * self.s_enhance, + self.data.low_res.shape[1] * self.s_enhance, + self.data.low_res.shape[2] * self.t_enhance, ) msg = ( - f'hr_data.shape {self.hr_data.shape[:-1]} and enhanced ' + f'hr_data.shape {self.data.high_res.shape[:-1]} and enhanced ' f'lr_data.shape {enhanced_shape} are not compatible with ' 'the given enhancement factors' ) - assert self.hr_data.shape[:-1] == enhanced_shape, msg + assert self.data.high_res.shape[:-1] == enhanced_shape, msg def get_sample_index(self, n_obs=None): """Get paired sample index, consisting of index for the low res sample @@ -141,10 +167,10 @@ def get_sample_index(self, n_obs=None): includes observation data.""" n_obs = n_obs or self.batch_size spatial_slice = uniform_box_sampler( - self.lr_data.shape, self.lr_sample_shape[:2] + self.data.low_res.shape, self.lr_sample_shape[:2] ) time_slice = uniform_time_sampler( - self.lr_data.shape, self.lr_sample_shape[2] * n_obs + self.data.low_res.shape, self.lr_sample_shape[2] * n_obs ) lr_index = (*spatial_slice, time_slice, 
self.lr_features) hr_index = [ @@ -155,8 +181,11 @@ def get_sample_index(self, n_obs=None): slice(s.start * self.t_enhance, s.stop * self.t_enhance) for s in lr_index[2:-1] ] - obs_index = (*hr_index, self.hr_out_features) - hr_index = (*hr_index, self.hr_features) + hr_feats = ( + [f for f in self.hr_source_features if f not in self.obs_features] + if self.use_proxy_obs + else self.hr_source_features + ) + hr_index = (*hr_index, hr_feats) - sample_index = (lr_index, hr_index, obs_index) - return sample_index[: len(self.data)] + return (lr_index, hr_index) diff --git a/sup3r/utilities/loss_metrics.py b/sup3r/utilities/loss_metrics.py index 07600a91f4..cda2188898 100644 --- a/sup3r/utilities/loss_metrics.py +++ b/sup3r/utilities/loss_metrics.py @@ -9,26 +9,33 @@ from tensorflow.keras.losses import MeanAbsoluteError, MeanSquaredError -class PhysicsBasedLoss(tf.keras.losses.Loss): - """Base class for physics-based loss metrics. This is meant to be used as a +class Sup3rLoss(tf.keras.losses.Loss): + """Base class for custom sup3r loss metrics. This is meant to be used as a base class for loss metrics that require specific input features.""" - def __init__(self, input_features='all'): + def __init__(self, gen_features='all', true_features=None): """Initialize the loss with given input features Parameters ---------- - input_features : list | str - List of input features that the loss metric will be calculated on. - This is meant to be used for physics-based loss metrics that - require specific input features. If 'all', the loss will be - calculated on all features. Otherwise, the loss will be calculated - on the features specified in the list. The order of features in - the list will be checked to determine the order of features in the - input tensors. + gen_features : list | str + List of generator output features that the loss metric will be + calculated on. If 'all', the loss will be calculated on all + generator features. 
Otherwise, the loss will be calculated on the + features specified in the list. The order of features in the list + will be checked to determine the order of features in the generator + output tensor. + true_features : list | str + List of true features that the loss metric will be calculated on. + If None, this will be the same as gen_features. The order of + features in the list will be checked to determine the order of + features in the ground truth tensor. """ super().__init__() - self.input_features = input_features + self.gen_features = gen_features + self.true_features = ( + true_features if true_features is not None else gen_features + ) def tf_derivative(x, axis=1): @@ -117,7 +124,7 @@ def gaussian_kernel(x1, x2, sigma=1.0): return result -class ExpLoss(tf.keras.losses.Loss): +class ExpLoss(Sup3rLoss): """Loss class for squared exponential difference""" def __call__(self, x1, x2): @@ -140,7 +147,7 @@ def __call__(self, x1, x2): return tf.reduce_mean(1 - tf.exp(-((x1 - x2) ** 2))) -class MmdLoss(tf.keras.losses.Loss): +class MmdLoss(Sup3rLoss): """Loss class for max mean discrepancy loss""" def __call__(self, x1, x2, sigma=1.0): @@ -169,7 +176,7 @@ def __call__(self, x1, x2, sigma=1.0): return mmd -class SpatialDerivativeLoss(tf.keras.losses.Loss): +class SpatialDerivativeLoss(Sup3rLoss): """Loss class to encourage accurary of spatial derivatives.""" LOSS_METRIC = MeanAbsoluteError() @@ -204,7 +211,7 @@ def __call__(self, x1, x2): return self.LOSS_METRIC(x1_div, x2_div) -class TemporalDerivativeLoss(tf.keras.losses.Loss): +class TemporalDerivativeLoss(Sup3rLoss): """Loss class to encourage accurary of temporal derivative.""" LOSS_METRIC = MeanAbsoluteError() @@ -238,7 +245,7 @@ def __call__(self, x1, x2): return self.LOSS_METRIC(x1_div, x2_div) -class CoarseMseLoss(tf.keras.losses.Loss): +class CoarseMseLoss(Sup3rLoss): """Loss class for coarse mse on spatial average of 5D tensor""" MSE_LOSS = MeanSquaredError() @@ -266,7 +273,7 @@ def __call__(self, x1, 
x2): return self.MSE_LOSS(x1_coarse, x2_coarse) -class SpatialExtremesLoss(tf.keras.losses.Loss): +class SpatialExtremesLoss(Sup3rLoss): """Loss class that encourages accuracy of the min/max values in the spatial domain. This does not include an additional MAE term""" @@ -301,7 +308,7 @@ def __call__(self, x1, x2): return (mae_min + mae_max) / 2 -class TemporalExtremesLoss(tf.keras.losses.Loss): +class TemporalExtremesLoss(Sup3rLoss): """Loss class that encourages accuracy of the min/max values in the timeseries. This does not include an additional mae term""" @@ -336,7 +343,7 @@ def __call__(self, x1, x2): return (mae_min + mae_max) / 2 -class SpatialFftLoss(tf.keras.losses.Loss): +class SpatialFftLoss(Sup3rLoss): """Loss class that encourages accuracy of the spatial frequency spectrum""" MAE_LOSS = MeanAbsoluteError() @@ -381,7 +388,7 @@ def __call__(self, x1, x2): return self.MAE_LOSS(x1_hat, x2_hat) -class SpatiotemporalFftLoss(tf.keras.losses.Loss): +class SpatiotemporalFftLoss(Sup3rLoss): """Loss class that encourages accuracy of the spatiotemporal frequency spectrum""" @@ -429,7 +436,7 @@ def __call__(self, x1, x2): return self.MAE_LOSS(x1_hat, x2_hat) -class LowResLoss(tf.keras.losses.Loss): +class LowResLoss(Sup3rLoss): """Content loss that is calculated by coarsening the synthetic and true high-resolution data pairs and then performing the pointwise content loss on the low-resolution fields""" @@ -582,7 +589,7 @@ def __call__(self, x1, x2): return self._tf_loss(x1, x2) + ex_loss -class PerceptualLoss(tf.keras.losses.Loss): +class PerceptualLoss(Sup3rLoss): """Perceptual loss that is calculated as MSE between feature maps of ground truth and synthetic data""" @@ -665,7 +672,7 @@ def __call__(self, x1, x2): return tf.reduce_mean(losses) -class SlicedWassersteinLoss(tf.keras.losses.Loss): +class SlicedWassersteinLoss(Sup3rLoss): """Loss class for sliced wasserstein distance loss""" def __init__(self, n_projections=1024): @@ -674,7 +681,7 @@ def 
__init__(self, n_projections=1024): n_projections : int number of random 1D projections to use - Note + Note: ---- Experimentally, we get stability in the SW metric when n_projections is at least 30% of the number of projection dimensions, which for us @@ -733,7 +740,7 @@ def __call__(self, x1, x2): return tf.reduce_mean((x1_sorted - x2_sorted) ** 2) -class MaterialDerivativeLoss(PhysicsBasedLoss): +class MaterialDerivativeLoss(Sup3rLoss): """Loss class for the material derivative. This is the left hand side of the Navier-Stokes equation and is equal to internal + external forces divided by density in general. Under certain simplifying assumptions, this @@ -746,19 +753,19 @@ class MaterialDerivativeLoss(PhysicsBasedLoss): LOSS_METRIC = MeanAbsoluteError() - def __init__(self, input_features): - super().__init__(input_features=input_features) + def __init__(self, gen_features): + super().__init__(gen_features=gen_features) self.u_inds = [ - i for i, f in enumerate(input_features) if f.startswith('u_') + i for i, f in enumerate(gen_features) if f.startswith('u_') ] self.v_inds = [ - i for i, f in enumerate(input_features) if f.startswith('v_') + i for i, f in enumerate(gen_features) if f.startswith('v_') ] self.u_heights = [ - f.split('_')[1] for f in input_features if f.startswith('u_') + f.split('_')[1] for f in gen_features if f.startswith('u_') ] self.v_heights = [ - f.split('_')[1] for f in input_features if f.startswith('v_') + f.split('_')[1] for f in gen_features if f.startswith('v_') ] assert len(self.u_inds) == len(self.v_inds), ( 'The number of u and v components must be equal for ' @@ -787,7 +794,7 @@ def _compute_md(self, x, feature): """ # df/dt height = feature.split('_')[1] - fidx = self.input_features.index(feature) + fidx = self.gen_features.index(feature) uidx = self.u_inds[self.u_heights.index(height)] vidx = self.v_inds[self.v_heights.index(height)] x_div = tf_derivative(x[..., fidx], axis=3) @@ -827,16 +834,16 @@ def __call__(self, x1, x2): 
assert len(x1.shape) == 5 and len(x2.shape) == 5, msg x1_div = tf.stack([ - self._compute_md(x1, feature) for feature in self.input_features + self._compute_md(x1, feature) for feature in self.gen_features ]) x2_div = tf.stack([ - self._compute_md(x2, feature) for feature in self.input_features + self._compute_md(x2, feature) for feature in self.gen_features ]) return self.LOSS_METRIC(x1_div, x2_div) -class GeothermalPhysicsLoss(PhysicsBasedLoss): +class GeothermalPhysicsLoss(Sup3rLoss): """Physics based loss for Geothermal applications TODO: Fill in call with appropriate physics equations. This is currently @@ -847,13 +854,48 @@ class GeothermalPhysicsLoss(PhysicsBasedLoss): def __call__(self, x1, x2): """Geothermal physics loss""" - check = x1.shape[-1] == len(self.input_features) - check &= x2.shape[-1] == len(self.input_features) + check = x1.shape[-1] == len(self.gen_features) + check &= x2.shape[-1] == len(self.true_features) msg = ( - f'Number of features in `x1`: {x1.shape[-1]}, `x2`: ' - f'{x2.shape[-1]} must match the length of `input_features`: ' - f'{len(self.input_features)}' + f'Number of features in `x1`: {x1.shape[-1]} must match the ' + f'length of `gen_features`: {len(self.gen_features)}, `x2`: ' + f'{x2.shape[-1]} must match the length of `true_features`: ' + f'{len(self.true_features)}' ) assert check, msg return self.LOSS_METRIC(x1, x2) + + +class GeothermalPhysicsLossWithObs(Sup3rLoss): + """Physics based loss for Geothermal applications + + TODO: Fill in call with appropriate physics equations. This is currently + just a dummy equation for testing. 
+ """ + + LOSS_METRIC = MeanAbsoluteError() + + def __call__(self, x1, x2): + """Geothermal physics loss""" + check = x1.shape[-1] == len(self.gen_features) + check &= x2.shape[-1] == len(self.true_features) + msg = ( + f'Number of features in `x1`: {x1.shape[-1]} must match the ' + f'length of `gen_features`: {len(self.gen_features)}, `x2`: ' + f'{x2.shape[-1]} must match the length of `true_features`: ' + f'{len(self.true_features)}' + ) + assert check, msg + + mask = tf.math.logical_not(tf.math.is_nan(x2)) + x1m = tf.boolean_mask(x1, mask) + x2m = tf.boolean_mask(x2, mask) + + physics_loss = tf.constant(1e-3, dtype=x1.dtype) + obs_loss = ( + tf.constant(0, dtype=x1.dtype) + if tf.math.reduce_all(tf.math.is_nan(x2m)) + else self.LOSS_METRIC(x1m, x2m) + ) + return physics_loss + obs_loss diff --git a/tests/batch_handlers/test_bh_h5_cc.py b/tests/batch_handlers/test_bh_h5_cc.py index 942e52eed4..f349d6eceb 100644 --- a/tests/batch_handlers/test_bh_h5_cc.py +++ b/tests/batch_handlers/test_bh_h5_cc.py @@ -104,7 +104,10 @@ def test_solar_batching_spatial(): s_enhance=2, t_enhance=1, sample_shape=(20, 20, 1), - feature_sets={'lr_only_features': ['clearsky_ghi', 'ghi']}, + feature_sets={ + 'lr_features': FEATURES_S, + 'hr_out_features': ['clearsky_ratio'], + }, ) for batch in batcher: @@ -162,7 +165,10 @@ def test_solar_multi_day_coarse_data(): s_enhance=4, t_enhance=3, sample_shape=(20, 20, 9), - feature_sets={'lr_only_features': ['clearsky_ghi', 'ghi']}, + feature_sets={ + 'lr_features': FEATURES_S, + 'hr_out_features': ['clearsky_ratio'], + }, ) for batch in batcher: @@ -176,7 +182,10 @@ def test_solar_multi_day_coarse_data(): # run another test with u/v on low res side but not high res features = ['clearsky_ratio', 'u', 'v', 'ghi', 'clearsky_ghi'] - feature_sets = {'lr_only_features': ['u', 'v', 'clearsky_ghi', 'ghi']} + feature_sets = { + 'lr_features': features, + 'hr_out_features': ['clearsky_ratio'], + } handler = DataHandlerH5SolarCC(pytest.FP_NSRDB, features, 
**dh_kwargs) batcher = BatchHandlerTesterCC( @@ -329,7 +338,10 @@ def test_surf_min_max_vars(): s_enhance=1, t_enhance=24, sample_shape=(20, 20, 72), - feature_sets={'lr_only_features': ['*_min_*', '*_max_*']}, + feature_sets={ + 'lr_features': surf_features, + 'hr_out_features': ['temperature_2m', 'relativehumidity_2m'], + }, mode='eager', ) diff --git a/tests/batch_queues/test_bq_general.py b/tests/batch_queues/test_bq_general.py index 52b2ffb257..2fd0607034 100644 --- a/tests/batch_queues/test_bq_general.py +++ b/tests/batch_queues/test_bq_general.py @@ -184,7 +184,10 @@ def test_pair_batch_queue_with_lr_only_features(): s_enhance=2, t_enhance=2, batch_size=4, - feature_sets={'lr_only_features': lr_only_features}, + feature_sets={ + 'lr_features': lr_features, + 'hr_out_features': FEATURES, + }, ) for lr, hr in zip(lr_containers, hr_containers) ] diff --git a/tests/bias/test_bias_correction.py b/tests/bias/test_bias_correction.py index 34c204f146..0062f3f2e1 100644 --- a/tests/bias/test_bias_correction.py +++ b/tests/bias/test_bias_correction.py @@ -471,10 +471,13 @@ def test_fwp_integration(): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 10, 10, 6, len(features)))) - model.meta['lr_features'] = features - model.meta['hr_out_features'] = features - model.meta['s_enhance'] = 3 - model.meta['t_enhance'] = 4 + model.set_model_params( + lr_features=features, + hr_out_features=features, + s_enhance=3, + t_enhance=4, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + ) with tempfile.TemporaryDirectory() as td: bias_fp = os.path.join(td, 'bc.h5') diff --git a/tests/bias/test_presrat_bias_correction.py b/tests/bias/test_presrat_bias_correction.py index c63ee4da96..62987ca422 100644 --- a/tests/bias/test_presrat_bias_correction.py +++ b/tests/bias/test_presrat_bias_correction.py @@ -757,10 +757,13 @@ def test_fwp_integration(tmp_path, presrat_params, fp_fut_cc): Sup3rGan.seed() model = Sup3rGan(fp_gen, 
fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 10, 10, 6, len(features)))) - model.meta['lr_features'] = features - model.meta['hr_out_features'] = features - model.meta['s_enhance'] = 3 - model.meta['t_enhance'] = 4 + model.set_model_params( + lr_features=features, + hr_out_features=features, + s_enhance=3, + t_enhance=4, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + ) out_dir = os.path.join(tmp_path, 'st_gan') model.save(out_dir) diff --git a/tests/bias/test_qdm_bias_correction.py b/tests/bias/test_qdm_bias_correction.py index d5722d43bf..e3aa709634 100644 --- a/tests/bias/test_qdm_bias_correction.py +++ b/tests/bias/test_qdm_bias_correction.py @@ -139,9 +139,9 @@ def test_qdm_bc(fp_fut_cc): # Each location can be all finite or all NaN, but not both for v in out: tmp = np.isfinite(out[v].reshape(-1, out[v].shape[-1])) - assert np.all( - np.all(tmp, axis=1) == ~np.all(~tmp, axis=1) - ), f'For each location of {v} it should be all finite or nonte' + assert np.all(np.all(tmp, axis=1) == ~np.all(~tmp, axis=1)), ( + f'For each location of {v} it should be all finite or nonte' + ) def test_parallel(fp_fut_cc): @@ -178,9 +178,9 @@ def test_parallel(fp_fut_cc): for k in out_s: assert k in out_p, f'Missing {k} in parallel run' - assert np.allclose( - out_s[k], out_p[k], equal_nan=True - ), f'Different results for {k}' + assert np.allclose(out_s[k], out_p[k], equal_nan=True), ( + f'Different results for {k}' + ) def test_fill_nan(fp_fut_cc): @@ -202,14 +202,14 @@ def test_fill_nan(fp_fut_cc): out = c.run(fill_extend=False) # Ignore non `params` parameters, such as window_center params = (v for v in out if v.endswith('params')) - assert np.all( - [np.isnan(out[v]).any() for v in params] - ), 'Assume at least one NaN value for each param' + assert np.all([np.isnan(out[v]).any() for v in params]), ( + 'Assume at least one NaN value for each param' + ) out = c.run() - assert np.all( - [np.isfinite(v).all() for v in out.values()] - ), 'All NaN 
values where supposed to be filled' + assert np.all([np.isfinite(v).all() for v in out.values()]), ( + 'All NaN values where supposed to be filled' + ) def test_save_file(tmp_path, fp_fut_cc): @@ -350,7 +350,8 @@ def test_bc_identity(tmp_path, fp_fut_cc, dist_params): idx = ~(np.isnan(original) | np.isnan(corrected)) assert np.allclose( - compute_if_dask(original)[idx], compute_if_dask(corrected)[idx]) + compute_if_dask(original)[idx], compute_if_dask(corrected)[idx] + ) def test_bc_identity_absolute(tmp_path, fp_fut_cc, dist_params): @@ -375,7 +376,8 @@ def test_bc_identity_absolute(tmp_path, fp_fut_cc, dist_params): idx = ~(np.isnan(original) | np.isnan(corrected)) assert np.allclose( - compute_if_dask(original)[idx], compute_if_dask(corrected)[idx]) + compute_if_dask(original)[idx], compute_if_dask(corrected)[idx] + ) def test_bc_model_constant(tmp_path, fp_fut_cc, dist_params): @@ -400,7 +402,8 @@ def test_bc_model_constant(tmp_path, fp_fut_cc, dist_params): idx = ~(np.isnan(original) | np.isnan(corrected)) assert np.allclose( - compute_if_dask(corrected)[idx] - compute_if_dask(original)[idx], -10) + compute_if_dask(corrected)[idx] - compute_if_dask(original)[idx], -10 + ) def test_bc_trend(tmp_path, fp_fut_cc, dist_params): @@ -425,7 +428,8 @@ def test_bc_trend(tmp_path, fp_fut_cc, dist_params): idx = ~(np.isnan(original) | np.isnan(corrected)) assert np.allclose( - compute_if_dask(corrected)[idx] - compute_if_dask(original)[idx], 10) + compute_if_dask(corrected)[idx] - compute_if_dask(original)[idx], 10 + ) def test_bc_trend_same_hist(tmp_path, fp_fut_cc, dist_params): @@ -449,7 +453,8 @@ def test_bc_trend_same_hist(tmp_path, fp_fut_cc, dist_params): idx = ~(np.isnan(original) | np.isnan(corrected)) assert np.allclose( - compute_if_dask(original)[idx], compute_if_dask(corrected)[idx]) + compute_if_dask(original)[idx], compute_if_dask(corrected)[idx] + ) def test_fwp_integration(tmp_path): @@ -493,10 +498,13 @@ def test_fwp_integration(tmp_path): 
Sup3rGan.seed() model = Sup3rGan(pytest.ST_FP_GEN, pytest.ST_FP_DISC, learning_rate=1e-4) _ = model.generate(np.ones((4, 10, 10, 6, len(features)))) - model.meta['lr_features'] = features - model.meta['hr_out_features'] = features - model.meta['s_enhance'] = 3 - model.meta['t_enhance'] = 4 + model.set_model_params( + lr_features=features, + hr_out_features=features, + s_enhance=3, + t_enhance=4, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + ) bias_fp = os.path.join(tmp_path, 'bc.h5') out_dir = os.path.join(tmp_path, 'st_gan') @@ -568,12 +576,12 @@ def test_fwp_integration(tmp_path): bc_chunk = bc_fwp.get_input_chunk(ichunk) chunk = fwp.get_input_chunk(ichunk) delta = bc_chunk.input_data - chunk.input_data - assert np.allclose( - delta[..., 0], -2.72, atol=1e-03 - ), 'U reference offset is -1' - assert np.allclose( - delta[..., 1], 2.72, atol=1e-03 - ), 'V reference offset is 1' + assert np.allclose(delta[..., 0], -2.72, atol=1e-03), ( + 'U reference offset is -1' + ) + assert np.allclose(delta[..., 1], 2.72, atol=1e-03), ( + 'V reference offset is 1' + ) kwargs = { 'model_kwargs': strat.model_kwargs, diff --git a/tests/conftest.py b/tests/conftest.py index 50863f48e3..3f0d1c4a32 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -63,15 +63,9 @@ def set_random_state(): RANDOM_GENERATOR.bit_generator.state = GLOBAL_STATE -@pytest.fixture(autouse=True) -def train_on_cpu(): - """Train on cpu for tests.""" - os.environ['CUDA_VISIBLE_DEVICES'] = '-1' - - @pytest.fixture(scope='package') -def gen_config_with_concat_masked(): - """Get generator config with custom concat masked layer.""" +def gen_config_with_obs_2d(): + """Get generator config with observation layers.""" def func(): return [ @@ -149,6 +143,181 @@ def func(): return func +@pytest.fixture(scope='package') +def gen_config_with_obs_3d(): + """Get generator config with observation layers.""" + + def func(): + return [ + { + 'class': 'FlexiblePadding', + 'paddings': [[0, 0], [3, 3], [3, 
3], [3, 3], [0, 0]], + 'mode': 'REFLECT', + }, + { + 'class': 'Conv3D', + 'filters': 2, + 'kernel_size': 3, + 'strides': 1, + 'activation': 'relu', + }, + {'class': 'Cropping3D', 'cropping': 2}, + { + 'class': 'SpatioTemporalExpansion', + 'temporal_mult': 2 + }, + { + 'class': 'FlexiblePadding', + 'paddings': [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + 'mode': 'REFLECT', + }, + { + 'class': 'Conv3D', + 'filters': 64, + 'kernel_size': 3, + 'strides': 1, + 'activation': 'relu', + }, + {'class': 'Cropping3D', 'cropping': 2}, + { + 'class': 'FlexiblePadding', + 'paddings': [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + 'mode': 'REFLECT', + }, + { + 'class': 'Conv3D', + 'filters': 64, + 'kernel_size': 3, + 'strides': 1, + 'activation': 'relu', + }, + {'class': 'Cropping3D', 'cropping': 2}, + { + 'class': 'SpatioTemporalExpansion', + 'spatial_mult': 2 + }, + {'class': 'Activation', 'activation': 'relu'}, + { + 'class': 'FlexiblePadding', + 'paddings': [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + 'mode': 'REFLECT', + }, + { + 'class': 'Conv3D', + 'filters': 2, + 'kernel_size': 3, + 'strides': 1, + 'activation': 'relu', + }, + {'class': 'Cropping3D', 'cropping': 2}, + {'class': 'Sup3rConcatObs', 'name': 'u_10m_obs'}, + {'class': 'Sup3rConcatObs', 'name': 'v_10m_obs'}, + { + 'class': 'FlexiblePadding', + 'paddings': [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + 'mode': 'REFLECT', + }, + { + 'class': 'Conv3D', + 'filters': 2, + 'kernel_size': 3, + 'strides': 1, + 'activation': 'relu', + }, + {'class': 'Cropping3D', 'cropping': 2}, + ] + + return func + + +@pytest.fixture(scope='package') +def gen_config_with_obs_3d_topo(): + """Get generator config with observation layers and topo layer.""" + + def func(): + return [ + { + 'class': 'FlexiblePadding', + 'paddings': [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + 'mode': 'REFLECT', + }, + { + 'class': 'Conv3D', + 'filters': 2, + 'kernel_size': 3, + 'strides': 1, + 'activation': 'relu', + }, + {'class': 'Cropping3D', 'cropping': 2}, + { 
+ 'class': 'SpatioTemporalExpansion', + 'temporal_mult': 2 + }, + { + 'class': 'FlexiblePadding', + 'paddings': [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + 'mode': 'REFLECT', + }, + { + 'class': 'Conv3D', + 'filters': 64, + 'kernel_size': 3, + 'strides': 1, + 'activation': 'relu', + }, + {'class': 'Cropping3D', 'cropping': 2}, + { + 'class': 'FlexiblePadding', + 'paddings': [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + 'mode': 'REFLECT', + }, + { + 'class': 'Conv3D', + 'filters': 64, + 'kernel_size': 3, + 'strides': 1, + 'activation': 'relu', + }, + {'class': 'Cropping3D', 'cropping': 2}, + { + 'class': 'SpatioTemporalExpansion', + 'spatial_mult': 2 + }, + {'class': 'Activation', 'activation': 'relu'}, + { + 'class': 'FlexiblePadding', + 'paddings': [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + 'mode': 'REFLECT', + }, + { + 'class': 'Conv3D', + 'filters': 2, + 'kernel_size': 3, + 'strides': 1, + 'activation': 'relu', + }, + {'class': 'Cropping3D', 'cropping': 2}, + {'class': 'Sup3rConcatObs', 'name': 'u_10m_obs'}, + {'class': 'Sup3rConcatObs', 'name': 'v_10m_obs'}, + {'class': 'Sup3rConcat', 'name': 'topography'}, + { + 'class': 'FlexiblePadding', + 'paddings': [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + 'mode': 'REFLECT', + }, + { + 'class': 'Conv3D', + 'filters': 2, + 'kernel_size': 3, + 'strides': 1, + 'activation': 'relu', + }, + {'class': 'Cropping3D', 'cropping': 2}, + ] + + return func + + @pytest.fixture(scope='package') def gen_config_with_topo(): """Get generator config with custom topo layer.""" diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py index 486b574bba..8a320aaa60 100644 --- a/tests/forward_pass/test_forward_pass.py +++ b/tests/forward_pass/test_forward_pass.py @@ -73,7 +73,7 @@ def test_fwp_nc_cc(): out_pattern=out_files, input_handler_name='DataHandlerNCforCC', pass_workers=None, - invert_uv=False + invert_uv=False, ) forward_pass = ForwardPass(strat) forward_pass.run(strat, node_index=0) @@ -173,7 
+173,7 @@ def test_fwp_spatial_only(input_files): out_pattern=out_files, pass_workers=1, output_workers=1, - invert_uv=False + invert_uv=False, ) forward_pass = ForwardPass(strat) assert strat.output_workers == 1 @@ -227,7 +227,7 @@ def test_fwp_nc(input_files): }, out_pattern=out_files, pass_workers=1, - invert_uv=False + invert_uv=False, ) forward_pass = ForwardPass(strat) assert forward_pass.strategy.pass_workers == 1 @@ -443,14 +443,12 @@ def test_fwp_chunking(input_files): 'time_slice': time_slice, }, ) - data_chunked = np.zeros( - ( - shape[0] * s_enhance, - shape[1] * s_enhance, - raw_tsteps * t_enhance, - len(model.hr_out_features), - ) - ) + data_chunked = np.zeros(( + shape[0] * s_enhance, + shape[1] * s_enhance, + raw_tsteps * t_enhance, + len(model.hr_out_features), + )) handlerNC = DataHandler( input_files, FEATURES, target=target, shape=shape ) @@ -564,19 +562,25 @@ def test_fwp_multi_step_model(input_files): fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') s_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s_model.meta['lr_features'] = ['u_100m', 'v_100m'] - s_model.meta['hr_out_features'] = ['u_100m', 'v_100m'] - assert s_model.s_enhance == 2 - assert s_model.t_enhance == 1 + s_model.set_model_params( + lr_features=['u_100m', 'v_100m'], + hr_out_features=['u_100m', 'v_100m'], + s_enhance=2, + t_enhance=1, + input_resolution={'spatial': '6km', 'temporal': '40min'}, + ) _ = s_model.generate(np.ones((4, 10, 10, 2))) fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - st_model.meta['lr_features'] = ['u_100m', 'v_100m'] - st_model.meta['hr_out_features'] = ['u_100m', 'v_100m'] - assert st_model.s_enhance == 3 - assert st_model.t_enhance == 4 + st_model.set_model_params( + lr_features=['u_100m', 'v_100m'], + hr_out_features=['u_100m', 'v_100m'], 
+ s_enhance=3, + t_enhance=4, + input_resolution={'spatial': '3km', 'temporal': '40min'}, + ) _ = st_model.generate(np.ones((4, 10, 10, 6, 2))) with tempfile.TemporaryDirectory() as td: diff --git a/tests/forward_pass/test_forward_pass_exo.py b/tests/forward_pass/test_forward_pass_exo.py index 9ccde0e106..436b30807d 100644 --- a/tests/forward_pass/test_forward_pass_exo.py +++ b/tests/forward_pass/test_forward_pass_exo.py @@ -54,25 +54,23 @@ def test_fwp_multi_step_model_topo_exoskip(input_files): fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') s1_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s1_model.meta['lr_features'] = ['u_100m', 'v_100m', 'topography'] - s1_model.meta['hr_out_features'] = ['u_100m', 'v_100m'] - s1_model.meta['s_enhance'] = 2 - s1_model.meta['t_enhance'] = 1 - s1_model.meta['input_resolution'] = { - 'spatial': '48km', - 'temporal': '60min', - } + s1_model.set_model_params( + lr_features=['u_100m', 'v_100m', 'topography'], + hr_out_features=['u_100m', 'v_100m'], + s_enhance=2, + t_enhance=1, + input_resolution={'spatial': '48km', 'temporal': '60min'}, + ) _ = s1_model.generate(np.ones((4, 10, 10, 3))) s2_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s2_model.meta['lr_features'] = ['u_100m', 'v_100m', 'topography'] - s2_model.meta['hr_out_features'] = ['u_100m', 'v_100m'] - s2_model.meta['s_enhance'] = 2 - s2_model.meta['t_enhance'] = 1 - s2_model.meta['input_resolution'] = { - 'spatial': '24km', - 'temporal': '60min', - } + s2_model.set_model_params( + lr_features=['u_100m', 'v_100m', 'topography'], + hr_out_features=['u_100m', 'v_100m'], + s_enhance=2, + t_enhance=1, + input_resolution={'spatial': '24km', 'temporal': '60min'}, + ) _ = s2_model.generate(np.ones((4, 10, 10, 3))) fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') @@ -281,14 +279,13 @@ def test_fwp_multi_step_model_topo_noskip(input_files): fp_gen = os.path.join(CONFIG_DIR, 
'spatiotemporal/gen_3x_4x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - st_model.meta['lr_features'] = ['u_100m', 'v_100m', 'topography'] - st_model.meta['hr_out_features'] = ['u_100m', 'v_100m'] - st_model.meta['s_enhance'] = 3 - st_model.meta['t_enhance'] = 4 - st_model.meta['input_resolution'] = { - 'spatial': '12km', - 'temporal': '60min', - } + st_model.set_model_params( + lr_features=['u_100m', 'v_100m', 'topography'], + hr_out_features=['u_100m', 'v_100m'], + s_enhance=3, + t_enhance=4, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + ) _ = st_model.generate(np.ones((4, 10, 10, 6, 3))) with tempfile.TemporaryDirectory() as td: @@ -479,6 +476,7 @@ def test_fwp_single_step_wind_hi_res_topo(input_files, plot=False): model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4) model.meta['lr_features'] = ['u_100m', 'v_100m', 'topography'] model.meta['hr_out_features'] = ['u_100m', 'v_100m'] + model.meta['hr_exo_features'] = ['topography'] model.meta['s_enhance'] = 2 model.meta['t_enhance'] = 2 model.meta['input_resolution'] = {'spatial': '8km', 'temporal': '60min'} @@ -561,6 +559,7 @@ def test_fwp_multi_step_wind_hi_res_topo(input_files, gen_config_with_topo): gen_config_with_topo('Sup3rConcat'), fp_disc, learning_rate=1e-4 ) s1_model.meta['lr_features'] = ['u_100m', 'v_100m', 'topography'] + s1_model.meta['hr_exo_features'] = ['topography'] s1_model.meta['hr_out_features'] = ['u_100m', 'v_100m'] s1_model.meta['s_enhance'] = 2 s1_model.meta['t_enhance'] = 1 @@ -586,6 +585,7 @@ def test_fwp_multi_step_wind_hi_res_topo(input_files, gen_config_with_topo): gen_config_with_topo('Sup3rConcat'), fp_disc, learning_rate=1e-4 ) s2_model.meta['lr_features'] = ['u_100m', 'v_100m', 'topography'] + s2_model.meta['hr_exo_features'] = ['topography'] s2_model.meta['hr_out_features'] = ['u_100m', 'v_100m'] s2_model.meta['s_enhance'] = 2 s2_model.meta['t_enhance'] = 1 @@ -663,6 +663,7 
@@ def test_fwp_wind_hi_res_topo_plus_linear(input_files, gen_config_with_topo): gen_config_with_topo('Sup3rConcat'), fp_disc, learning_rate=1e-4 ) s_model.meta['lr_features'] = ['u_100m', 'v_100m', 'topography'] + s_model.meta['hr_exo_features'] = ['topography'] s_model.meta['hr_out_features'] = ['u_100m', 'v_100m'] s_model.meta['s_enhance'] = 2 s_model.meta['t_enhance'] = 1 @@ -896,6 +897,7 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza( gen_config_with_topo('Sup3rConcat'), fp_disc, learning_rate=1e-4 ) s1_model.meta['lr_features'] = ['u_100m', 'v_100m', 'topography', 'sza'] + s1_model.meta['hr_exo_features'] = ['topography'] s1_model.meta['hr_out_features'] = ['u_100m', 'v_100m'] s1_model.meta['s_enhance'] = 2 s1_model.meta['t_enhance'] = 1 @@ -929,6 +931,7 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza( gen_config_with_topo('Sup3rConcat'), fp_disc, learning_rate=1e-4 ) s2_model.meta['lr_features'] = ['u_100m', 'v_100m', 'topography', 'sza'] + s2_model.meta['hr_exo_features'] = ['topography'] s2_model.meta['hr_out_features'] = ['u_100m', 'v_100m'] s2_model.meta['s_enhance'] = 2 s2_model.meta['t_enhance'] = 1 @@ -941,6 +944,7 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza( fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') st_model = Sup3rGan(gen_t_model, fp_disc, learning_rate=1e-4) st_model.meta['lr_features'] = ['u_100m', 'v_100m', 'sza'] + st_model.meta['hr_exo_features'] = ['sza'] st_model.meta['hr_out_features'] = ['u_100m', 'v_100m'] st_model.meta['s_enhance'] = 3 st_model.meta['t_enhance'] = 2 @@ -1037,8 +1041,14 @@ def test_solar_multistep_exo(gen_config_with_topo): model1 = Sup3rGan(fp_gen, fp_disc) _ = model1.generate(np.ones((4, 10, 10, len(features1)))) model1.set_norm_stats({'clearsky_ratio': 0.7}, {'clearsky_ratio': 0.04}) - model1.meta['input_resolution'] = {'spatial': '8km', 'temporal': '40min'} - model1.set_model_params(lr_features=features1, hr_out_features=features1) + model1.set_model_params( + lr_features=features1, + 
hr_out_features=features1, + hr_exo_features=[], + s_enhance=2, + t_enhance=1, + input_resolution={'spatial': '8km', 'temporal': '40min'}, + ) features2 = ['U_200m', 'V_200m', 'topography'] @@ -1069,6 +1079,8 @@ def test_solar_multistep_exo(gen_config_with_topo): lr_features=features2, hr_out_features=features2[:-1], hr_exo_features=features2[-1:], + s_enhance=2, + t_enhance=1, ) features_in_3 = ['clearsky_ratio', 'U_200m', 'V_200m'] @@ -1081,9 +1093,12 @@ def test_solar_multistep_exo(gen_config_with_topo): {'U_200m': 4.2, 'V_200m': 5.6, 'clearsky_ratio': 0.7}, {'U_200m': 1.1, 'V_200m': 1.3, 'clearsky_ratio': 0.04}, ) - model3.meta['input_resolution'] = {'spatial': '2km', 'temporal': '40min'} model3.set_model_params( - lr_features=features_in_3, hr_out_features=features_out_3 + lr_features=features_in_3, + hr_out_features=features_out_3, + s_enhance=1, + t_enhance=8, + input_resolution={'spatial': '2km', 'temporal': '40min'}, ) with tempfile.TemporaryDirectory() as td: diff --git a/tests/forward_pass/test_forward_pass_obs.py b/tests/forward_pass/test_forward_pass_obs.py index cd56ce70b2..333c38a3da 100644 --- a/tests/forward_pass/test_forward_pass_obs.py +++ b/tests/forward_pass/test_forward_pass_obs.py @@ -9,7 +9,7 @@ import pytest from rex import Outputs -from sup3r.models import Sup3rGanWithObs +from sup3r.models import Sup3rGan from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy from sup3r.utilities.pytest.helpers import make_fake_dset from sup3r.utilities.utilities import RANDOM_GENERATOR @@ -108,27 +108,24 @@ def h5_obs_file(tmpdir_factory): @pytest.mark.parametrize('obs_file', ['nc_obs_file', 'h5_obs_file']) -def test_fwp_with_obs( - input_file, obs_file, gen_config_with_concat_masked, request -): +def test_fwp_with_obs(input_file, obs_file, gen_config_with_obs_2d, request): """Test a special model trained to condition output on input observations.""" obs_file = request.getfixturevalue(obs_file) - Sup3rGanWithObs.seed() - - model = 
Sup3rGanWithObs( - gen_config_with_concat_masked(), - pytest.S_FP_DISC, - onshore_obs_frac={'spatial': 0.1}, - loss_obs_weight=0.1, - learning_rate=1e-4, + Sup3rGan.seed() + + model = Sup3rGan( + gen_config_with_obs_2d(), pytest.S_FP_DISC, learning_rate=1e-4 + ) + model.set_model_params( + input_resolution={'spatial': '16km', 'temporal': '3600min'}, + lr_features=['u_10m', 'v_10m'], + hr_exo_features=['u_10m_obs', 'v_10m_obs'], + hr_out_features=['u_10m', 'v_10m'], + s_enhance=2, + t_enhance=1, ) - model.meta['input_resolution'] = {'spatial': '16km', 'temporal': '3600min'} - model.meta['lr_features'] = ['u_10m', 'v_10m'] - model.meta['hr_out_features'] = ['u_10m', 'v_10m'] - model.meta['s_enhance'] = 2 - model.meta['t_enhance'] = 1 with tempfile.TemporaryDirectory() as td: exo_tmp = { @@ -151,10 +148,7 @@ def test_fwp_with_obs( ] }, } - _ = model.generate( - np.ones((6, 10, 10, 2)), - exogenous_data=exo_tmp - ) + _ = model.generate(np.ones((6, 10, 10, 2)), exogenous_data=exo_tmp) model_dir = os.path.join(td, 'test') model.save(model_dir) @@ -182,7 +176,7 @@ def test_fwp_with_obs( handler = ForwardPassStrategy( input_file, model_kwargs=model_kwargs, - model_class='Sup3rGanWithObs', + model_class='Sup3rGan', fwp_chunk_shape=fwp_chunk_shape, input_handler_kwargs=input_handler_kwargs, spatial_pad=0, diff --git a/tests/forward_pass/test_multi_step.py b/tests/forward_pass/test_multi_step.py index f77b0258f0..716b97321a 100644 --- a/tests/forward_pass/test_multi_step.py +++ b/tests/forward_pass/test_multi_step.py @@ -33,10 +33,20 @@ def test_multi_step_model(features): model1 = Sup3rGan(fp_gen1, fp_disc) model2 = Sup3rGan(fp_gen2, fp_disc) - model1.meta['input_resolution'] = {'spatial': '27km', 'temporal': '64min'} - model2.meta['input_resolution'] = {'spatial': '9km', 'temporal': '16min'} - model1.set_model_params(lr_features=FEATURES, hr_out_features=FEATURES) - model2.set_model_params(lr_features=features, hr_out_features=features) + model1.set_model_params( + 
lr_features=FEATURES, + hr_out_features=FEATURES, + input_resolution={'spatial': '27km', 'temporal': '64min'}, + s_enhance=3, + t_enhance=4, + ) + model2.set_model_params( + lr_features=features, + hr_out_features=features, + input_resolution={'spatial': '9km', 'temporal': '16min'}, + s_enhance=3, + t_enhance=4, + ) _ = model1.generate(np.ones((4, 10, 10, 6, len(FEATURES)))) _ = model2.generate(np.ones((4, 10, 10, 6, len(features)))) @@ -93,12 +103,27 @@ def test_multi_step_norm(norm_option): {'u_100m': 0.1, 'v_100m': 0.8}, {'u_100m': 0.04, 'v_100m': 0.02} ) - model1.meta['input_resolution'] = {'spatial': '27km', 'temporal': '64min'} - model2.meta['input_resolution'] = {'spatial': '9km', 'temporal': '16min'} - model3.meta['input_resolution'] = {'spatial': '3km', 'temporal': '4min'} - model1.set_model_params(lr_features=FEATURES, hr_out_features=FEATURES) - model2.set_model_params(lr_features=FEATURES, hr_out_features=FEATURES) - model3.set_model_params(lr_features=FEATURES, hr_out_features=FEATURES) + model1.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES, + input_resolution={'spatial': '27km', 'temporal': '64min'}, + s_enhance=3, + t_enhance=4, + ) + model2.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES, + input_resolution={'spatial': '9km', 'temporal': '16min'}, + s_enhance=3, + t_enhance=4, + ) + model3.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES, + input_resolution={'spatial': '3km', 'temporal': '4min'}, + s_enhance=3, + t_enhance=4, + ) _ = model1.generate(np.ones((4, 10, 10, 6, len(FEATURES)))) _ = model2.generate(np.ones((4, 10, 10, 6, len(FEATURES)))) @@ -146,11 +171,20 @@ def test_spatial_then_temporal_gan(): {'u_100m': 0.3, 'v_100m': 0.9}, {'u_100m': 0.02, 'v_100m': 0.07} ) - model1.meta['input_resolution'] = {'spatial': '12km', 'temporal': '40min'} - model2.meta['input_resolution'] = {'spatial': '6km', 'temporal': '40min'} - - model1.set_model_params(lr_features=FEATURES, 
hr_out_features=FEATURES) - model2.set_model_params(lr_features=FEATURES, hr_out_features=FEATURES) + model1.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES, + input_resolution={'spatial': '12km', 'temporal': '40min'}, + s_enhance=2, + t_enhance=1, + ) + model2.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES, + input_resolution={'spatial': '6km', 'temporal': '40min'}, + s_enhance=3, + t_enhance=4, + ) with tempfile.TemporaryDirectory() as td: fp1 = os.path.join(td, 'model1') @@ -184,11 +218,20 @@ def test_temporal_then_spatial_gan(): {'u_100m': 0.3, 'v_100m': 0.9}, {'u_100m': 0.02, 'v_100m': 0.07} ) - model1.meta['input_resolution'] = {'spatial': '12km', 'temporal': '40min'} - model2.meta['input_resolution'] = {'spatial': '6km', 'temporal': '40min'} - - model1.set_model_params(lr_features=FEATURES, hr_out_features=FEATURES) - model2.set_model_params(lr_features=FEATURES, hr_out_features=FEATURES) + model1.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES, + input_resolution={'spatial': '12km', 'temporal': '40min'}, + s_enhance=2, + t_enhance=1, + ) + model2.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES, + input_resolution={'spatial': '6km', 'temporal': '40min'}, + s_enhance=3, + t_enhance=4, + ) with tempfile.TemporaryDirectory() as td: fp1 = os.path.join(td, 'model1') @@ -215,8 +258,13 @@ def test_spatial_gan_then_linear_interp(): model1.set_norm_stats( {'u_100m': 0.1, 'v_100m': 0.2}, {'u_100m': 0.04, 'v_100m': 0.02} ) - model1.meta['input_resolution'] = {'spatial': '12km', 'temporal': '60min'} - model1.set_model_params(lr_features=FEATURES, hr_out_features=FEATURES) + model1.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + s_enhance=2, + t_enhance=1, + ) with tempfile.TemporaryDirectory() as td: fp1 = os.path.join(td, 'model1') @@ -241,8 +289,13 @@ def test_solar_multistep(): model1 = 
Sup3rGan(fp_gen, fp_disc) _ = model1.generate(np.ones((4, 10, 10, len(features1)))) model1.set_norm_stats({'clearsky_ratio': 0.7}, {'clearsky_ratio': 0.04}) - model1.meta['input_resolution'] = {'spatial': '8km', 'temporal': '40min'} - model1.set_model_params(lr_features=features1, hr_out_features=features1) + model1.set_model_params( + lr_features=features1, + hr_out_features=features1, + s_enhance=2, + t_enhance=1, + input_resolution={'spatial': '8km', 'temporal': '40min'}, + ) features2 = ['U_200m', 'V_200m'] fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') @@ -252,8 +305,13 @@ def test_solar_multistep(): model2.set_norm_stats( {'U_200m': 4.2, 'V_200m': 5.6}, {'U_200m': 1.1, 'V_200m': 1.3} ) - model2.meta['input_resolution'] = {'spatial': '4km', 'temporal': '40min'} - model2.set_model_params(lr_features=features2, hr_out_features=features2) + model2.set_model_params( + lr_features=features2, + hr_out_features=features2, + s_enhance=2, + t_enhance=1, + input_resolution={'spatial': '4km', 'temporal': '40min'}, + ) features_in_3 = ['clearsky_ratio', 'U_200m', 'V_200m'] features_out_3 = ['clearsky_ratio'] @@ -265,9 +323,12 @@ def test_solar_multistep(): {'U_200m': 4.2, 'V_200m': 5.6, 'clearsky_ratio': 0.7}, {'U_200m': 1.1, 'V_200m': 1.3, 'clearsky_ratio': 0.04}, ) - model3.meta['input_resolution'] = {'spatial': '2km', 'temporal': '40min'} model3.set_model_params( - lr_features=features_in_3, hr_out_features=features_out_3 + lr_features=features_in_3, + hr_out_features=features_out_3, + s_enhance=1, + t_enhance=8, + input_resolution={'spatial': '2km', 'temporal': '40min'}, ) with tempfile.TemporaryDirectory() as td: diff --git a/tests/forward_pass/test_surface_model.py b/tests/forward_pass/test_surface_model.py index 5ea91720e7..9f6ae554b7 100644 --- a/tests/forward_pass/test_surface_model.py +++ b/tests/forward_pass/test_surface_model.py @@ -159,8 +159,8 @@ def test_multi_step_surface(s_enhance=2, t_enhance=2): temporal_model_kwargs={'model_dirs': 
temporal_dir}) for model in ms_model.models: - assert isinstance(model.s_enhance, int) - assert isinstance(model.t_enhance, int) + assert isinstance(model.s_enhance, (int, np.integer)) + assert isinstance(model.t_enhance, (int, np.integer)) x = np.ones((2, 10, 10, len(FEATURES))) with pytest.raises(AssertionError): diff --git a/tests/output/test_qa.py b/tests/output/test_qa.py index 9854591300..1b7f11ff81 100644 --- a/tests/output/test_qa.py +++ b/tests/output/test_qa.py @@ -55,10 +55,13 @@ def test_qa(input_files, ext): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 10, 10, 6, len(TRAIN_FEATURES)))) - model.meta['lr_features'] = TRAIN_FEATURES - model.meta['hr_out_features'] = MODEL_OUT_FEATURES - model.meta['s_enhance'] = 3 - model.meta['t_enhance'] = 4 + model.set_model_params( + lr_features=TRAIN_FEATURES, + hr_out_features=MODEL_OUT_FEATURES, + s_enhance=3, + t_enhance=4, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + ) with tempfile.TemporaryDirectory() as td: out_dir = os.path.join(td, 'st_gan') model.save(out_dir) diff --git a/tests/pipeline/test_cli.py b/tests/pipeline/test_cli.py index a91d1d8c74..adfd8d68c3 100644 --- a/tests/pipeline/test_cli.py +++ b/tests/pipeline/test_cli.py @@ -201,20 +201,21 @@ def test_fwd_pass_with_bc_cli(runner, input_files): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) - model.meta['lr_features'] = FEATURES - model.meta['hr_out_features'] = FEATURES[:2] - assert model.s_enhance == 3 - assert model.t_enhance == 4 + model.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES[:2], + s_enhance=3, + t_enhance=4, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + ) with tempfile.TemporaryDirectory() as td: out_dir = os.path.join(td, 'st_gan') model.save(out_dir) - n_chunks = np.prod( - [ - int(np.ceil(ds / fs)) - for ds, fs in zip([*shape, data_shape[2]], 
fwp_chunk_shape) - ] - ) + n_chunks = np.prod([ + int(np.ceil(ds / fs)) + for ds, fs in zip([*shape, data_shape[2]], fwp_chunk_shape) + ]) out_files = os.path.join(td, 'out_{file_id}.nc') cache_pattern = os.path.join(td, 'cache_{feature}.nc') log_pattern = os.path.join(td, 'logs', 'log_{node_index}.log') @@ -299,20 +300,21 @@ def test_fwd_pass_cli(runner, input_files): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) - model.meta['lr_features'] = FEATURES - model.meta['hr_out_features'] = FEATURES[:2] - assert model.s_enhance == 3 - assert model.t_enhance == 4 + model.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES[:2], + s_enhance=3, + t_enhance=4, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + ) with tempfile.TemporaryDirectory() as td: out_dir = os.path.join(td, 'st_gan') model.save(out_dir) - n_chunks = np.prod( - [ - int(np.ceil(ds / fs)) - for ds, fs in zip([*shape, data_shape[2]], fwp_chunk_shape) - ] - ) + n_chunks = np.prod([ + int(np.ceil(ds / fs)) + for ds, fs in zip([*shape, data_shape[2]], fwp_chunk_shape) + ]) out_files = os.path.join(td, 'out_{file_id}.nc') cache_pattern = os.path.join(td, 'cache_{feature}.nc') log_pattern = os.path.join(td, 'logs', 'log_{node_index}.log') @@ -359,15 +361,17 @@ def test_pipeline_fwp_qa(runner, input_files): Sup3rGan.seed() model = Sup3rGan(pytest.ST_FP_GEN, pytest.ST_FP_DISC, learning_rate=1e-4) + model.set_model_params( + lr_features=FEATURES, + hr_out_features=FEATURES[:2], + s_enhance=3, + t_enhance=4, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + ) input_resolution = {'spatial': '12km', 'temporal': '60min'} - model.meta['input_resolution'] = input_resolution assert model.input_resolution == input_resolution assert model.output_resolution == {'spatial': '4km', 'temporal': '15min'} _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) - model.meta['lr_features'] = FEATURES - 
model.meta['hr_out_features'] = FEATURES[:2] - assert model.s_enhance == 3 - assert model.t_enhance == 4 with tempfile.TemporaryDirectory() as td: out_dir = os.path.join(td, 'st_gan') diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index d05225bb2c..465fec9f92 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -43,13 +43,15 @@ def test_fwp_pipeline_with_bc(input_files): model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) input_resolution = {'spatial': '12km', 'temporal': '60min'} - model.meta['input_resolution'] = input_resolution + model.set_model_params( + input_resolution=input_resolution, + s_enhance=3, + t_enhance=4, + lr_features=FEATURES, + hr_out_features=FEATURES[:2], + ) assert model.input_resolution == input_resolution assert model.output_resolution == {'spatial': '4km', 'temporal': '15min'} - _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) - model.meta['lr_features'] = FEATURES - model.meta['hr_out_features'] = FEATURES[:2] - model.meta['hr_exo_features'] = FEATURES[2:] assert model.s_enhance == 3 assert model.t_enhance == 4 @@ -179,13 +181,15 @@ def test_fwp_pipeline(input_files): model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) input_resolution = {'spatial': '12km', 'temporal': '60min'} - model.meta['input_resolution'] = input_resolution + model.set_model_params( + input_resolution=input_resolution, + s_enhance=3, + t_enhance=4, + lr_features=FEATURES, + hr_out_features=FEATURES[:2], + ) assert model.input_resolution == input_resolution assert model.output_resolution == {'spatial': '4km', 'temporal': '15min'} - _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) - model.meta['lr_features'] = FEATURES - model.meta['hr_out_features'] = FEATURES[:2] - model.meta['hr_exo_features'] = FEATURES[2:] assert model.s_enhance == 3 assert model.t_enhance == 4 @@ 
-290,11 +294,13 @@ def test_fwp_pipeline_with_mask(input_files): model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) input_resolution = {'spatial': '12km', 'temporal': '60min'} - model.meta['input_resolution'] = input_resolution - _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) - model.meta['lr_features'] = FEATURES - model.meta['hr_out_features'] = FEATURES[:2] - model.meta['hr_exo_features'] = FEATURES[2:] + model.set_model_params( + input_resolution=input_resolution, + s_enhance=3, + t_enhance=4, + lr_features=FEATURES, + hr_out_features=FEATURES[:2], + ) test_context = click.Context(click.Command('pipeline'), obj={}) with tempfile.TemporaryDirectory() as td, test_context as ctx: @@ -391,13 +397,15 @@ def test_multiple_fwp_pipeline(input_files): model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) input_resolution = {'spatial': '12km', 'temporal': '60min'} - model.meta['input_resolution'] = input_resolution + model.set_model_params( + input_resolution=input_resolution, + s_enhance=3, + t_enhance=4, + lr_features=FEATURES, + hr_out_features=FEATURES[:2], + ) assert model.input_resolution == input_resolution assert model.output_resolution == {'spatial': '4km', 'temporal': '15min'} - _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) - model.meta['lr_features'] = FEATURES - model.meta['hr_out_features'] = FEATURES[:2] - model.meta['hr_exo_features'] = FEATURES[2:] assert model.s_enhance == 3 assert model.t_enhance == 4 diff --git a/tests/rasterizers/test_dual.py b/tests/rasterizers/test_dual.py index 725c2fdb05..6499096a2f 100644 --- a/tests/rasterizers/test_dual.py +++ b/tests/rasterizers/test_dual.py @@ -34,10 +34,10 @@ def test_dual_rasterizer_shapes(full_shape=(20, 20)): s_enhance=2, t_enhance=1, ) - assert pair_rasterizer.lr_data.shape == ( - pair_rasterizer.hr_data.shape[0] // 2, - pair_rasterizer.hr_data.shape[1] // 2, - 
*pair_rasterizer.hr_data.shape[2:], + assert pair_rasterizer.data.low_res.shape == ( + pair_rasterizer.data.high_res.shape[0] // 2, + pair_rasterizer.data.high_res.shape[1] // 2, + *pair_rasterizer.data.high_res.shape[2:], ) @@ -70,7 +70,7 @@ def test_dual_nan_fill(full_shape=(20, 20)): t_enhance=1, ) - assert not np.isnan(pair_rasterizer.lr_data.as_array()).any() + assert not np.isnan(pair_rasterizer.data.low_res.as_array()).any() def test_regrid_caching(full_shape=(20, 20)): @@ -110,9 +110,9 @@ def test_regrid_caching(full_shape=(20, 20)): assert np.array_equal( lr_container_new.data[FEATURES][...], - pair_rasterizer.lr_data[FEATURES][...], + pair_rasterizer.data.low_res[FEATURES][...], ) assert np.array_equal( hr_container_new.data[FEATURES][...], - pair_rasterizer.hr_data[FEATURES][...], + pair_rasterizer.data.high_res[FEATURES][...], ) diff --git a/tests/rasterizers/test_exo.py b/tests/rasterizers/test_exo.py index 20e00e4145..d6241db00c 100644 --- a/tests/rasterizers/test_exo.py +++ b/tests/rasterizers/test_exo.py @@ -23,7 +23,7 @@ TARGET = (13.67, 125.0) SHAPE = (8, 8) S_ENHANCE = [1, 4] -T_ENHANCE = [1, 1] +T_ENHANCE = [1, 2] def test_exo_data_init(): @@ -32,7 +32,16 @@ def test_exo_data_init(): ExoData(steps=['dummy']) -@pytest.mark.parametrize('feature', ['topography', 'sza']) +@pytest.mark.parametrize( + 'feature', + [ + 'topography', + 'sza', + 'latitude_feature', + 'longitude_feature', + 'time_feature', + ], +) def test_exo_cache(feature): """Test exogenous data caching and re-load""" # no cached data @@ -370,7 +379,8 @@ def test_obs_agg(s_enhance, with_nans): ) agg_obs = np.asarray(te._get_data_3d()) true_obs = ( - te.source_handler['u_10m'] + te + .source_handler['u_10m'] .coarsen({ 'south_north': 4 // s_enhance, 'west_east': 4 // s_enhance, diff --git a/tests/samplers/test_feature_sets.py b/tests/samplers/test_feature_sets.py index 7c87bb72b6..8eafad85d4 100644 --- a/tests/samplers/test_feature_sets.py +++ b/tests/samplers/test_feature_sets.py @@ 
-8,31 +8,58 @@ @pytest.mark.parametrize( - ['features', 'lr_only_features', 'hr_exo_features'], + ['features', 'lr_features', 'hr_exo_features', 'hr_out_features'], [ - (['V_100m'], ['V_100m'], []), - (['U_100m'], ['V_100m'], ['V_100m']), - (['U_100m'], [], ['U_100m']), - (['U_100m', 'V_100m'], [], ['U_100m']), - (['U_100m', 'V_100m'], [], ['V_100m', 'U_100m']), + (['V_100m'], ['V_100m'], [], ['U_100m']), + (['U_100m'], ['V_100m'], ['V_100m'], []), ], ) -def test_feature_errors(features, lr_only_features, hr_exo_features): +def test_feature_errors( + features, lr_features, hr_exo_features, hr_out_features +): """Each of these feature combinations should raise an error due to no features left in hr output or bad ordering""" + with pytest.raises((RuntimeError, AssertionError)): + _ = Sampler( + DummyData(data_shape=(20, 20, 10), features=features), + sample_shape=(5, 5, 4), + feature_sets={ + 'lr_features': lr_features, + 'hr_exo_features': hr_exo_features, + 'hr_out_features': hr_out_features, + }, + ) + + +@pytest.mark.parametrize( + ['lr_features', 'hr_exo_features', 'hr_out_features'], + [ + (['V_100m', 'topography'], ['topography'], ['V_100m_obs']), + ( + ['V_100m', 'topography'], + ['topography', 'V_100m_obs'], + ['V_100m_obs'], + ), + ], +) +def test_sampler_feature_sets(lr_features, hr_exo_features, hr_out_features): + """Each of these feature combinations should pass without raising an + error.""" + feats = set(lr_features) | set(hr_exo_features) | set(hr_out_features) sampler = Sampler( - DummyData(data_shape=(20, 20, 10), features=features), + DummyData(data_shape=(20, 20, 10), features=feats), sample_shape=(5, 5, 4), feature_sets={ - 'lr_only_features': lr_only_features, + 'lr_features': lr_features, 'hr_exo_features': hr_exo_features, + 'hr_out_features': hr_out_features, }, ) - with pytest.raises((RuntimeError, AssertionError)): - _ = sampler.lr_features - _ = sampler.hr_out_features - _ = sampler.hr_exo_features + _ = sampler.lr_features + _ = 
sampler.hr_out_features + _ = sampler.hr_exo_features + _ = sampler.obs_features @pytest.mark.parametrize( @@ -87,7 +114,7 @@ def test_mixed_lr_hr_features(lr_features, hr_features, hr_exo_features): @pytest.mark.parametrize( - ['features', 'lr_only_features', 'hr_exo_features'], + ['lr_features', 'hr_exo_features', 'hr_out_features'], [ ( [ @@ -102,8 +129,8 @@ def test_mixed_lr_hr_features(lr_features, hr_features, hr_exo_features): 'dewpoint_temperature', 'topography', ], - ['pressure', 'kx', 'dewpoint_temperature'], ['topography'], + ['u_10m', 'u_100m', 'u_200m', 'u_80m', 'u_120m', 'u_140m'], ), ( [ @@ -119,8 +146,8 @@ def test_mixed_lr_hr_features(lr_features, hr_features, hr_exo_features): 'topography', 'srl', ], - ['pressure', 'kx', 'dewpoint_temperature'], ['topography', 'srl'], + ['u_10m', 'u_100m', 'u_200m', 'u_80m', 'u_120m', 'u_140m'], ), ( [ @@ -134,8 +161,8 @@ def test_mixed_lr_hr_features(lr_features, hr_features, hr_exo_features): 'kx', 'dewpoint_temperature', ], - ['pressure', 'kx', 'dewpoint_temperature'], [], + ['u_10m', 'u_100m', 'u_200m', 'u_80m', 'u_120m', 'u_140m'], ), ( [ @@ -148,12 +175,12 @@ def test_mixed_lr_hr_features(lr_features, hr_features, hr_exo_features): 'topography', 'srl', ], - [], ['topography', 'srl'], + ['u_10m', 'u_100m', 'u_200m', 'u_80m', 'u_120m', 'u_140m'], ), ], ) -def test_dual_feature_sets(features, lr_only_features, hr_exo_features): +def test_dual_feature_sets(lr_features, hr_exo_features, hr_out_features): """Each of these feature combinations should work fine with the dual sampler""" @@ -161,21 +188,136 @@ def test_dual_feature_sets(features, lr_only_features, hr_exo_features): lr_containers = [ DummyData( data_shape=(10, 10, 20), - features=[f.lower() for f in features], + features=[f.lower() for f in lr_features], + ), + DummyData( + data_shape=(12, 12, 15), + features=[f.lower() for f in lr_features], + ), + ] + hr_containers = [ + DummyData( + data_shape=(20, 20, 40), + features=[f.lower() for f in 
lr_features], + ), + DummyData( + data_shape=(24, 24, 30), + features=[f.lower() for f in lr_features], + ), + ] + sampler_pairs = [ + DualSampler( + Sup3rDataset(low_res=lr.data, high_res=hr.data), + hr_sample_shape, + s_enhance=2, + t_enhance=2, + feature_sets={ + 'lr_features': lr_features, + 'hr_exo_features': hr_exo_features, + 'hr_out_features': hr_out_features, + }, + ) + for lr, hr in zip(lr_containers, hr_containers) + ] + + for pair in sampler_pairs: + _ = pair.lr_features + _ = pair.hr_out_features + _ = pair.hr_exo_features + + +@pytest.mark.parametrize( + ['lr_features', 'hr_exo_features', 'hr_out_features'], + [ + ( + [ + 'u_10m', + 'u_100m', + 'u_200m', + 'u_80m', + 'u_120m', + 'u_140m', + 'pressure', + 'kx', + 'dewpoint_temperature', + 'topography', + ], + ['topography', 'u_100m_obs', 'v_100m_obs'], + ['u_10m', 'u_100m', 'u_200m', 'u_80m', 'u_120m', 'u_140m'], + ), + ( + [ + 'u_10m', + 'u_100m', + 'u_200m', + 'u_80m', + 'u_120m', + 'u_140m', + 'pressure', + 'kx', + 'dewpoint_temperature', + 'topography', + 'srl', + ], + ['topography', 'srl', 'u_100m_obs', 'v_100m_obs'], + ['u_10m', 'u_100m', 'u_200m', 'u_80m', 'u_120m', 'u_140m'], + ), + ( + [ + 'u_10m', + 'u_100m', + 'u_200m', + 'u_80m', + 'u_120m', + 'u_140m', + 'pressure', + 'kx', + 'dewpoint_temperature', + ], + [], + ['u_10m', 'u_100m', 'u_200m', 'u_80m', 'u_120m', 'u_140m'], + ), + ( + [ + 'u_10m', + 'u_100m', + 'u_200m', + 'u_80m', + 'u_120m', + 'u_140m', + 'topography', + 'srl', + ], + ['topography', 'srl', 'u_100m_obs', 'v_100m_obs'], + ['u_10m', 'u_100m', 'u_200m', 'u_80m', 'u_120m', 'u_140m'], + ), + ], +) +def test_dual_feature_sets_with_obs( + lr_features, hr_exo_features, hr_out_features +): + """Each of these feature combinations should work fine with the dual + sampler when obs features are included""" + + hr_sample_shape = (8, 8, 10) + lr_containers = [ + DummyData( + data_shape=(10, 10, 20), + features=[f.lower() for f in lr_features], ), DummyData( data_shape=(12, 12, 15), - 
features=[f.lower() for f in features], + features=[f.lower() for f in lr_features], ), ] hr_containers = [ DummyData( data_shape=(20, 20, 40), - features=[f.lower() for f in features], + features=[f.lower() for f in lr_features], ), DummyData( data_shape=(24, 24, 30), - features=[f.lower() for f in features], + features=[f.lower() for f in lr_features], ), ] sampler_pairs = [ @@ -184,10 +326,12 @@ def test_dual_feature_sets(features, lr_only_features, hr_exo_features): hr_sample_shape, s_enhance=2, t_enhance=2, + proxy_obs_kwargs={'onshore_obs_frac': {'spatial': 0.1}}, feature_sets={ - 'features': features, - 'lr_only_features': lr_only_features, - 'hr_exo_features': hr_exo_features}, + 'lr_features': lr_features, + 'hr_exo_features': hr_exo_features, + 'hr_out_features': hr_out_features, + }, ) for lr, hr in zip(lr_containers, hr_containers) ] @@ -196,3 +340,4 @@ def test_dual_feature_sets(features, lr_only_features, hr_exo_features): _ = pair.lr_features _ = pair.hr_out_features _ = pair.hr_exo_features + _ = pair.obs_features diff --git a/tests/samplers/test_with_obs.py b/tests/samplers/test_with_obs.py new file mode 100644 index 0000000000..38cda97758 --- /dev/null +++ b/tests/samplers/test_with_obs.py @@ -0,0 +1,154 @@ +"""Test sampler behavior with proxy observations.""" + +import numpy as np +import pytest + +from sup3r.preprocessing import DualSampler, Sampler +from sup3r.preprocessing.base import Sup3rDataset +from sup3r.utilities.pytest.helpers import DummyData + +LR_FEATURES = ['u_100m', 'v_100m', 'temperature_2m'] +HR_OUT_FEATURES = ['u_100m', 'v_100m'] +OBS_FEATURES = ['u_100m_obs', 'v_100m_obs'] + + +def _make_sampler( + sampler_cls, + hr_shape, + sample_shape, + batch_size, + proxy_obs_kwargs, + hr_features=None, +): + """Create either Sampler or DualSampler with proxy obs feature sets.""" + hr_features = hr_features or LR_FEATURES + feature_sets = { + 'lr_features': LR_FEATURES, + 'hr_out_features': HR_OUT_FEATURES, + 'hr_exo_features': [ + *[f for 
f in hr_features if f == 'topography'], + *OBS_FEATURES, + ], + } + + if sampler_cls is Sampler: + data = DummyData(data_shape=hr_shape, features=hr_features) + return Sampler( + data, + sample_shape=sample_shape, + batch_size=batch_size, + proxy_obs_kwargs=proxy_obs_kwargs, + feature_sets=feature_sets, + ) + + lr_shape = (hr_shape[0] // 2, hr_shape[1] // 2, hr_shape[2]) + lr = DummyData(data_shape=lr_shape, features=LR_FEATURES).data.high_res + hr = DummyData(data_shape=hr_shape, features=hr_features).data.high_res + data = Sup3rDataset(low_res=lr, high_res=hr) + return DualSampler( + data, + sample_shape=sample_shape, + batch_size=batch_size, + s_enhance=2, + t_enhance=1, + proxy_obs_kwargs=proxy_obs_kwargs, + feature_sets=feature_sets, + ) + + +def _get_hr_batch(sampler): + """Extract the high-res batch for Sampler and DualSampler outputs.""" + batch = next(sampler) + return batch[-1] if isinstance(batch, tuple) else batch + + +@pytest.mark.parametrize('sampler_cls', [Sampler, DualSampler]) +@pytest.mark.parametrize( + 'sample_shape, obs_fracs, expected', + [ + ((30, 30, 1), {'spatial': 0.4, 'time': 1.0}, 0.4), + ((30, 30, 12), {'spatial': 0.4, 'time': 0.5}, 0.2), + ], +) +def test_proxy_obs_appended_and_fraction( + sampler_cls, sample_shape, obs_fracs, expected +): + """Proxy obs channels are appended and sampled at configured fraction.""" + sampler = _make_sampler( + sampler_cls=sampler_cls, + hr_shape=(60, 60, 500), + sample_shape=sample_shape, + batch_size=20, + proxy_obs_kwargs={'onshore_obs_frac': obs_fracs}, + ) + + batch = _get_hr_batch(sampler) + obs = batch[..., -2:] + + expected_channels = 5 if sampler_cls is Sampler else 4 + assert batch.shape[-1] == expected_channels + assert obs.shape[-1] == 2 + + observed_frac = np.isfinite(obs[..., 0]).mean() + assert np.isclose(observed_frac, expected, atol=0.05) + + +@pytest.mark.parametrize('sampler_cls', [Sampler, DualSampler]) +def test_proxy_obs_fraction_bounds_with_ranges(sampler_cls): + """Observed 
fraction stays within expected range for sampled fractions.""" + s_range = [0.1, 0.3] + t_range = [0.2, 0.6] + sampler = _make_sampler( + sampler_cls=sampler_cls, + hr_shape=(80, 80, 500), + sample_shape=(40, 40, 20), + batch_size=8, + proxy_obs_kwargs={ + 'onshore_obs_frac': {'spatial': s_range, 'time': t_range} + }, + ) + + batch = _get_hr_batch(sampler) + obs = batch[..., -2:] + observed_by_sample = np.isfinite(obs[..., 0]).mean(axis=(1, 2, 3)) + + lower = s_range[0] * t_range[0] + upper = s_range[1] * t_range[1] + assert np.all(observed_by_sample >= (lower - 0.02)) + assert np.all(observed_by_sample <= (upper + 0.02)) + + +@pytest.mark.parametrize('sampler_cls', [Sampler, DualSampler]) +def test_proxy_obs_onshore_offshore_topography_fractions(sampler_cls): + """Onshore and offshore obs fractions are applied by topography mask.""" + sampler = _make_sampler( + sampler_cls=sampler_cls, + hr_shape=(80, 80, 500), + sample_shape=(40, 40, 12), + batch_size=8, + proxy_obs_kwargs={ + 'onshore_obs_frac': {'spatial': 0.8, 'time': 1.0}, + 'offshore_obs_frac': {'spatial': 0.1, 'time': 1.0}, + }, + hr_features=[*LR_FEATURES, 'topography'], + ) + + topo_var = sampler.data.high_res['topography'] + topo = np.ones(topo_var.shape, dtype=np.float32) + topo[:, : topo.shape[1] // 2, :] = -1.0 + sampler.data.high_res['topography'] = (topo_var.dims, topo) + + batch = _get_hr_batch(sampler) + topo_idx = sampler.hr_source_features.index('topography') + topo_sample = batch[..., topo_idx] + obs = batch[..., -2:] + + onshore = topo_sample > 0 + offshore = ~onshore + + onshore_frac = np.isfinite(obs[..., 0][onshore]).mean() + offshore_frac = np.isfinite(obs[..., 0][offshore]).mean() + + assert np.isclose(onshore_frac, 0.8, atol=0.12) + assert np.isclose(offshore_frac, 0.1, atol=0.08) + assert onshore_frac > offshore_frac diff --git a/tests/training/test_end_to_end.py b/tests/training/test_end_to_end.py index 597a2d17b0..a81d647857 100644 --- a/tests/training/test_end_to_end.py +++ 
b/tests/training/test_end_to_end.py @@ -59,7 +59,7 @@ def test_end_to_end(): train_containers=[train_dh], val_containers=[val_dh], n_batches=2, - batch_size=10, + batch_size=4, sample_shape=(12, 12, 16), s_enhance=3, t_enhance=4, diff --git a/tests/training/test_train_conditional_exo.py b/tests/training/test_train_conditional_exo.py index 60e61d96bd..84b9697723 100644 --- a/tests/training/test_train_conditional_exo.py +++ b/tests/training/test_train_conditional_exo.py @@ -170,7 +170,10 @@ def test_wind_non_cc_hi_res_st_topo_mom2( lower_models={1: model_mom1}, n_batches=n_batches, sample_shape=(12, 12, 24), - feature_sets={'hr_exo_features': ['topography']}, + feature_sets={ + 'lr_features': ['u_100m', 'v_100m', 'topography'], + 'hr_out_features': ['u_100m', 'v_100m'], + 'hr_exo_features': ['topography']}, mode='eager' ) diff --git a/tests/training/test_train_conditioned_obs.py b/tests/training/test_train_conditioned_obs.py deleted file mode 100644 index 9cd8d35720..0000000000 --- a/tests/training/test_train_conditioned_obs.py +++ /dev/null @@ -1,102 +0,0 @@ -"""Test the training of super resolution GANs with exogenous observation -data.""" - -import os -import tempfile - -import numpy as np -import pytest - -from sup3r.models import Sup3rGanWithObs -from sup3r.preprocessing import ( - BatchHandler, - DataHandler, -) -from sup3r.utilities.utilities import RANDOM_GENERATOR - -SHAPE = (20, 20) -FEATURES_W = ['u_10m', 'v_10m'] -TARGET_W = (39.01, -105.15) - - -@pytest.mark.parametrize('gen_config', ['gen_config_with_concat_masked']) -def test_fixed_wind_obs(gen_config, request): - """Test a special model which fixes observations mid network with - ``Sup3rConcatObs`` layer.""" - - gen_config = request.getfixturevalue(gen_config)() - kwargs = { - 'file_paths': pytest.FP_WTK, - 'features': FEATURES_W, - 'target': TARGET_W, - 'shape': SHAPE, - } - - train_handler = DataHandler(**kwargs, time_slice=slice(None, 3000, 10)) - - val_handler = DataHandler(**kwargs, 
time_slice=slice(3000, None, 10)) - batcher = BatchHandler( - [train_handler], - [val_handler], - batch_size=2, - n_batches=1, - s_enhance=2, - t_enhance=1, - sample_shape=(20, 20, 1), - ) - - Sup3rGanWithObs.seed() - - model = Sup3rGanWithObs( - gen_config, - pytest.S_FP_DISC, - onshore_obs_frac={'spatial': 0.1}, - loss_obs_weight=0.1, - learning_rate=1e-4, - ) - model.meta['hr_out_features'] = ['u_10m', 'v_10m'] - test_mask = model._get_full_obs_mask(np.zeros((1, 20, 20, 1, 1))).numpy() - frac = 1 - test_mask.sum() / test_mask.size - assert np.abs(0.1 - frac) < test_mask.size / (2 * np.sqrt(test_mask.size)) - assert model.obs_features == ['u_10m_obs', 'v_10m_obs'] - with tempfile.TemporaryDirectory() as td: - model_kwargs = { - 'input_resolution': {'spatial': '16km', 'temporal': '3600min'}, - 'n_epoch': 3, - 'weight_gen_advers': 0.0, - 'train_gen': True, - 'train_disc': False, - 'checkpoint_int': None, - 'out_dir': os.path.join(td, 'test_{epoch}'), - } - - model.train(batcher, **model_kwargs) - - loaded = model.load(os.path.join(td, 'test_2')) - loaded.train(batcher, **model_kwargs) - - x = RANDOM_GENERATOR.uniform(0, 1, (4, 30, 30, len(FEATURES_W))) - u10m_obs = RANDOM_GENERATOR.uniform(0, 1, (4, 60, 60, 1)) - v10m_obs = RANDOM_GENERATOR.uniform(0, 1, (4, 60, 60, 1)) - mask = RANDOM_GENERATOR.choice([True, False], (60, 60, 1), p=[0.9, 0.1]) - u10m_obs[:, mask] = np.nan - v10m_obs[:, mask] = np.nan - - with pytest.raises(RuntimeError): - y = model.generate(x, exogenous_data=None) - - exo_tmp = { - 'u_10m_obs': { - 'steps': [{'model': 0, 'combine_type': 'layer', 'data': u10m_obs}] - }, - 'v_10m_obs': { - 'steps': [{'model': 0, 'combine_type': 'layer', 'data': v10m_obs}] - }, - } - y = model.generate(x, exogenous_data=exo_tmp) - - assert y.dtype == np.float32 - assert y.shape[0] == x.shape[0] - assert y.shape[1] == x.shape[1] * 2 - assert y.shape[2] == x.shape[2] * 2 - assert y.shape[3] == len(FEATURES_W) diff --git a/tests/training/test_train_exo.py 
b/tests/training/test_train_exo.py index 02b283bf8a..86cb157663 100644 --- a/tests/training/test_train_exo.py +++ b/tests/training/test_train_exo.py @@ -19,24 +19,24 @@ @pytest.mark.parametrize( - ('CustomLayer', 'features', 'lr_only_features', 'mode'), + ('CustomLayer', 'lr_features', 'hr_out_features', 'mode'), [ - ('Sup3rAdder', FEATURES_W, ['temperature_100m'], 'lazy'), - ('Sup3rConcat', FEATURES_W, ['temperature_100m'], 'lazy'), - ('Sup3rAdder', FEATURES_W[1:], [], 'lazy'), - ('Sup3rConcat', FEATURES_W[1:], [], 'lazy'), - ('Sup3rConcat', FEATURES_W[1:], [], 'eager'), + ('Sup3rAdder', FEATURES_W, FEATURES_W[1:-1], 'lazy'), + ('Sup3rConcat', FEATURES_W, FEATURES_W[1:-1], 'lazy'), + ('Sup3rAdder', FEATURES_W[1:], FEATURES_W[1:-1], 'lazy'), + ('Sup3rConcat', FEATURES_W[1:], FEATURES_W[1:-1], 'lazy'), + ('Sup3rConcat', FEATURES_W[1:], FEATURES_W[1:-1], 'eager'), ], ) def test_wind_hi_res_topo( - CustomLayer, features, lr_only_features, mode, gen_config_with_topo + CustomLayer, lr_features, hr_out_features, mode, gen_config_with_topo ): """Test a special wind model for non cc with the custom Sup3rAdder or Sup3rConcat layer that adds/concatenates hi-res topography in the middle of the network.""" kwargs = { 'file_paths': pytest.FP_WTK, - 'features': features, + 'features': lr_features, 'target': TARGET_W, 'shape': SHAPE, } @@ -54,7 +54,8 @@ def test_wind_hi_res_topo( t_enhance=1, sample_shape=(20, 20, 1), feature_sets={ - 'lr_only_features': lr_only_features, + 'lr_features': lr_features, + 'hr_out_features': hr_out_features, 'hr_exo_features': ['topography'], }, mode=mode, @@ -80,16 +81,18 @@ def test_wind_hi_res_topo( out_dir=os.path.join(td, 'test_{epoch}'), ) - assert model.lr_features == [f.lower() for f in features] - assert model.hr_out_features == ['u_100m', 'v_100m'] + assert model.lr_features == [f.lower() for f in lr_features] + assert model.hr_out_features == [f.lower() for f in hr_out_features] assert model.hr_exo_features == ['topography'] assert 
'test_0' in os.listdir(td) - assert model.meta['hr_out_features'] == ['u_100m', 'v_100m'] + assert model.meta['hr_out_features'] == [ + f.lower() for f in hr_out_features + ] assert model.meta['class'] == 'Sup3rGan' assert 'topography' in batcher.hr_exo_features assert 'topography' not in model.hr_out_features - x = RANDOM_GENERATOR.uniform(0, 1, (4, 30, 30, len(features))) + x = RANDOM_GENERATOR.uniform(0, 1, (4, 30, 30, len(lr_features))) hi_res_topo = RANDOM_GENERATOR.uniform(0, 1, (4, 60, 60, 1)) with pytest.raises(RuntimeError): @@ -108,4 +111,4 @@ def test_wind_hi_res_topo( assert y.shape[0] == x.shape[0] assert y.shape[1] == x.shape[1] * 2 assert y.shape[2] == x.shape[2] * 2 - assert y.shape[3] == len(features) - len(lr_only_features) - 1 + assert y.shape[3] == len(hr_out_features) diff --git a/tests/training/test_train_exo_cc.py b/tests/training/test_train_exo_cc.py index 30225c8d4f..6d5622895e 100644 --- a/tests/training/test_train_exo_cc.py +++ b/tests/training/test_train_exo_cc.py @@ -19,16 +19,16 @@ @pytest.mark.parametrize( - ('CustomLayer', 'features', 'lr_only_features'), + ('CustomLayer', 'lr_features', 'hr_out_features'), [ - ('Sup3rAdder', FEATURES_W, ['temperature_100m']), - ('Sup3rConcat', FEATURES_W, ['temperature_100m']), - ('Sup3rAdder', FEATURES_W[1:], []), - ('Sup3rConcat', FEATURES_W[1:], []), + ('Sup3rAdder', FEATURES_W, FEATURES_W[1:-1]), + ('Sup3rConcat', FEATURES_W, FEATURES_W[1:-1]), + ('Sup3rAdder', FEATURES_W[1:], FEATURES_W[1:-1]), + ('Sup3rConcat', FEATURES_W[1:], FEATURES_W[1:-1]), ], ) def test_wind_hi_res_topo( - CustomLayer, features, lr_only_features, gen_config_with_topo + CustomLayer, lr_features, hr_out_features, gen_config_with_topo ): """Test a special wind cc model with the custom Sup3rAdder or Sup3rConcat layer that adds/concatenates hi-res topography in the middle of the @@ -36,7 +36,7 @@ def test_wind_hi_res_topo( handler = DataHandlerH5WindCC( pytest.FP_WTK, - features=features, + features=lr_features, 
target=TARGET_W, shape=SHAPE, time_slice=slice(None, None, 2), @@ -50,7 +50,8 @@ def test_wind_hi_res_topo( s_enhance=2, sample_shape=(20, 20), feature_sets={ - 'lr_only_features': lr_only_features, + 'lr_features': lr_features, + 'hr_out_features': hr_out_features, 'hr_exo_features': ['topography'], }, mode='eager', @@ -73,16 +74,16 @@ def test_wind_hi_res_topo( out_dir=os.path.join(td, 'test_{epoch}'), ) - assert model.lr_features == lowered(features) - assert model.hr_out_features == ['u_100m', 'v_100m'] + assert model.lr_features == lowered(lr_features) + assert model.hr_out_features == lowered(hr_out_features) assert model.hr_exo_features == ['topography'] assert 'test_0' in os.listdir(td) - assert model.meta['hr_out_features'] == ['u_100m', 'v_100m'] + assert model.meta['hr_out_features'] == lowered(hr_out_features) assert model.meta['class'] == 'Sup3rGan' assert 'topography' in batcher.hr_exo_features assert 'topography' not in model.hr_out_features - x = RANDOM_GENERATOR.uniform(0, 1, (4, 30, 30, len(features))) + x = RANDOM_GENERATOR.uniform(0, 1, (4, 30, 30, len(lr_features))) hi_res_topo = RANDOM_GENERATOR.uniform(0, 1, (4, 60, 60, 1)) with pytest.raises(RuntimeError): @@ -101,4 +102,4 @@ def test_wind_hi_res_topo( assert y.shape[0] == x.shape[0] assert y.shape[1] == x.shape[1] * 2 assert y.shape[2] == x.shape[2] * 2 - assert y.shape[3] == x.shape[3] - len(lr_only_features) - 1 + assert y.shape[3] == len(hr_out_features) diff --git a/tests/training/test_train_exo_dc.py b/tests/training/test_train_exo_dc.py index 34ccfca83b..fea1a192dc 100644 --- a/tests/training/test_train_exo_dc.py +++ b/tests/training/test_train_exo_dc.py @@ -43,8 +43,11 @@ def test_wind_dc_hi_res_topo(CustomLayer): n_batches=1, s_enhance=2, sample_shape=(20, 20, 8), - feature_sets={'hr_exo_features': ['topography']}, - ) + feature_sets={ + 'lr_features': ['u_100m', 'v_100m', 'topography'], + 'hr_out_features': ['u_100m', 'v_100m'], + 'hr_exo_features': ['topography']}, + ) batcher = 
BatchHandlerTesterDC( train_containers=[handler], @@ -55,7 +58,10 @@ def test_wind_dc_hi_res_topo(CustomLayer): n_batches=1, s_enhance=2, sample_shape=(10, 10, 8), - feature_sets={'hr_exo_features': ['topography']}, + feature_sets={ + 'lr_features': ['u_100m', 'v_100m', 'topography'], + 'hr_out_features': ['u_100m', 'v_100m'], + 'hr_exo_features': ['topography']}, ) gen_model = [ diff --git a/tests/training/test_train_gan.py b/tests/training/test_train_gan.py index 5ed5a6bbf5..1a8f5f0100 100644 --- a/tests/training/test_train_gan.py +++ b/tests/training/test_train_gan.py @@ -210,6 +210,10 @@ def test_train(fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, n_epoch=8): learning_rate=lr, loss={'MeanAbsoluteError': {}, 'MeanSquaredError': {}}, ) + dummy.set_model_params( + input_resolution={'spatial': '30km', 'temporal': '60min'}, + batch_handler=batch_handler, + ) for batch in batch_handler: out_og = model._tf_generate(batch.low_res) @@ -252,7 +256,7 @@ def test_train(fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, n_epoch=8): 'loss_func', [ {'SlicedWassersteinLoss': {}}, - {'GeothermalPhysicsLoss': {'input_features': ['u_100m']}}, + {'GeothermalPhysicsLoss': {'gen_features': ['u_100m']}}, ], ) def test_train_with_custom_loss(loss_func, n_epoch=8): @@ -406,7 +410,9 @@ def test_input_res_check(): with pytest.raises(RuntimeError): model.set_model_params( - input_resolution={'spatial': '22km', 'temporal': '9min'} + input_resolution={'spatial': '22km', 'temporal': '9min'}, + s_enhance=3, + t_enhance=4, ) diff --git a/tests/training/test_train_solar.py b/tests/training/test_train_solar.py index 83860808d3..0b82b30a54 100644 --- a/tests/training/test_train_solar.py +++ b/tests/training/test_train_solar.py @@ -51,7 +51,10 @@ def test_solar_cc_model(hr_steps): s_enhance=1, t_enhance=8, sample_shape=(20, 20, hr_steps), - feature_sets={'lr_only_features': ['clearsky_ghi', 'ghi']}, + feature_sets={ + 'lr_features': FEATURES_S, + 'hr_out_features': ['clearsky_ratio'], + 
}, ) fp_gen = os.path.join(CONFIG_DIR, 'sup3rcc/gen_solar_1x_8x_1f.json') @@ -59,7 +62,8 @@ def test_solar_cc_model(hr_steps): Sup3rGan.seed() model = SolarCC( - fp_gen, fp_disc, learning_rate=1e-4, loss='MeanAbsoluteError' + fp_gen, fp_disc, learning_rate=1e-4, loss='MeanAbsoluteError', + t_enhance=8 ) with tempfile.TemporaryDirectory() as td: @@ -126,7 +130,10 @@ def test_solar_cc_model_spatial(): s_enhance=5, t_enhance=1, sample_shape=(20, 20), - feature_sets={'lr_only_features': ['clearsky_ghi', 'ghi']}, + feature_sets={ + 'lr_features': FEATURES_S, + 'hr_out_features': ['clearsky_ratio'], + }, ) fp_gen = os.path.join(CONFIG_DIR, 'sup3rcc/gen_solar_5x_1x_1f.json') @@ -178,7 +185,10 @@ def test_solar_custom_loss(): s_enhance=1, t_enhance=8, sample_shape=(5, 5, 24), - feature_sets={'lr_only_features': ['clearsky_ghi', 'ghi']}, + feature_sets={ + 'lr_features': FEATURES_S, + 'hr_out_features': ['clearsky_ratio'], + }, ) fp_gen = os.path.join(CONFIG_DIR, 'sup3rcc/gen_solar_1x_8x_1f.json') diff --git a/tests/training/test_train_with_obs.py b/tests/training/test_train_with_obs.py new file mode 100644 index 0000000000..0a637783b3 --- /dev/null +++ b/tests/training/test_train_with_obs.py @@ -0,0 +1,430 @@ +"""Test the training of super resolution GANs with exogenous observation +data.""" + +import itertools +import os +import tempfile + +import numpy as np +import pytest + +from sup3r.models import Sup3rGan +from sup3r.preprocessing import ( + BatchHandler, + Container, + DataHandler, + DualBatchHandler, + DualRasterizer, +) +from sup3r.preprocessing.samplers import DualSampler +from sup3r.utilities.pytest.helpers import BatchHandlerTesterFactory +from sup3r.utilities.utilities import RANDOM_GENERATOR + +DualBatchHandlerWithObsTester = BatchHandlerTesterFactory( + DualBatchHandler, DualSampler +) + +SHAPE = (20, 20) +FEATURES_W = ['u_10m', 'v_10m'] +TARGET_W = (39.01, -105.15) + + +@pytest.mark.parametrize( + 'gen_config, sample_shape, t_enhance, fp_disc', + [ + 
('gen_config_with_obs_2d', (20, 20, 1), 1, pytest.S_FP_DISC), + ('gen_config_with_obs_3d', (20, 20, 10), 2, pytest.ST_FP_DISC), + ], +) +def test_train_proxy_obs( + gen_config, sample_shape, t_enhance, fp_disc, request +): + """Test a special model which conditions model output on observations + with a ``Sup3rConcatObs`` layer.""" + + gen_config = request.getfixturevalue(gen_config)() + kwargs = { + 'file_paths': pytest.FP_WTK, + 'features': FEATURES_W, + 'target': TARGET_W, + 'shape': SHAPE, + } + + train_handler = DataHandler(**kwargs, time_slice=slice(None, 3000, 10)) + + val_handler = DataHandler(**kwargs, time_slice=slice(3000, None, 10)) + batcher = BatchHandler( + [train_handler], + [val_handler], + batch_size=2, + n_batches=1, + s_enhance=2, + t_enhance=t_enhance, + sample_shape=sample_shape, + proxy_obs_kwargs={'onshore_obs_frac': {'spatial': 0.1}}, + feature_sets={ + 'lr_features': FEATURES_W, + 'hr_exo_features': [f'{feat}_obs' for feat in FEATURES_W], + 'hr_out_features': FEATURES_W, + }, + ) + + Sup3rGan.seed() + + model = Sup3rGan( + gen_config, + fp_disc, + learning_rate=1e-4, + loss={ + 'GeothermalPhysicsLossWithObs': { + 'gen_features': FEATURES_W, + 'true_features': [f'{feat}_obs' for feat in FEATURES_W], + }, + 'GeothermalPhysicsLoss': {'gen_features': FEATURES_W}, + }, + ) + model.meta['hr_out_features'] = FEATURES_W + with tempfile.TemporaryDirectory() as td: + model_kwargs = { + 'input_resolution': {'spatial': '16km', 'temporal': '3600min'}, + 'n_epoch': 3, + 'weight_gen_advers': 0.0, + 'train_gen': True, + 'train_disc': False, + 'checkpoint_int': None, + 'out_dir': os.path.join(td, 'test_{epoch}'), + } + + model.train(batcher, **model_kwargs) + + loaded = model.load(os.path.join(td, 'test_2')) + loaded.train(batcher, **model_kwargs) + + assert model.obs_features == [f'{feat}_obs' for feat in FEATURES_W] + + if t_enhance == 1: + x = RANDOM_GENERATOR.uniform(0, 1, (4, 30, 30, len(FEATURES_W))) + u10m_obs = RANDOM_GENERATOR.uniform(0, 1, (4, 60, 
60, 1)) + v10m_obs = RANDOM_GENERATOR.uniform(0, 1, (4, 60, 60, 1)) + else: + x = RANDOM_GENERATOR.uniform(0, 1, (4, 30, 30, 10, len(FEATURES_W))) + u10m_obs = RANDOM_GENERATOR.uniform(0, 1, (4, 60, 60, 20, 1)) + v10m_obs = RANDOM_GENERATOR.uniform(0, 1, (4, 60, 60, 20, 1)) + mask = RANDOM_GENERATOR.choice( + [True, False], u10m_obs.shape[1:], p=[0.9, 0.1] + ) + u10m_obs[:, mask] = np.nan + v10m_obs[:, mask] = np.nan + + with pytest.raises(RuntimeError): + y = model.generate(x, exogenous_data=None) + + exo_tmp = { + 'u_10m_obs': { + 'steps': [{'model': 0, 'combine_type': 'layer', 'data': u10m_obs}] + }, + 'v_10m_obs': { + 'steps': [{'model': 0, 'combine_type': 'layer', 'data': v10m_obs}] + }, + } + y = model.generate(x, exogenous_data=exo_tmp) + + assert y.dtype == np.float32 + assert y.shape[0] == x.shape[0] + assert y.shape[1] == x.shape[1] * 2 + assert y.shape[2] == x.shape[2] * 2 + assert y.shape[-1] == len(FEATURES_W) + if y.ndim == 5: + assert y.shape[3] == x.shape[3] * t_enhance + + +@pytest.mark.parametrize( + 'gen_config, sample_shape, t_enhance, fp_disc', + [ + ('gen_config_with_obs_2d', (20, 20, 1), 1, pytest.S_FP_DISC), + ('gen_config_with_obs_3d', (20, 20, 10), 2, pytest.ST_FP_DISC), + ], +) +def test_train_real_obs(gen_config, sample_shape, t_enhance, fp_disc, request): + """Test model training with only sparse high resolution ground truth data. 
+ This should skip any calculations involving the discriminator - since + train_disc=False""" + + gen_config = request.getfixturevalue(gen_config)() + kwargs = { + 'features': FEATURES_W, + 'target': TARGET_W, + 'shape': (20, 20), + } + hr_handler = DataHandler( + pytest.FP_WTK, + **kwargs, + time_slice=slice(None, None, 1), + ) + + lr_handler = DataHandler( + pytest.FP_ERA, + features=FEATURES_W, + time_slice=slice(None, None, t_enhance), + ) + + dual_rasterizer = DualRasterizer( + data={'low_res': lr_handler.data, 'high_res': hr_handler.data}, + s_enhance=2, + t_enhance=t_enhance, + run_qa=False, + ) + obs_data = dual_rasterizer.high_res.copy() + for feat in FEATURES_W: + tmp = np.full(obs_data[feat].shape, np.nan) + lat_ids = RANDOM_GENERATOR.choice( + obs_data[feat].shape[0], size=5, replace=False + ) + lon_ids = RANDOM_GENERATOR.choice( + obs_data[feat].shape[1], size=5, replace=False + ) + for ilat, ilon in itertools.product(lat_ids, lon_ids): + tmp[ilat, ilon, :] = obs_data[feat][ilat, ilon] + obs_data[f'{feat}_obs'] = (obs_data[feat].dims, tmp) + + dual_with_obs = Container( + data={ + 'low_res': dual_rasterizer.low_res, + 'high_res': obs_data, + } + ) + + batch_handler = DualBatchHandlerWithObsTester( + train_containers=[dual_with_obs], + val_containers=[], + sample_shape=sample_shape, + batch_size=3, + s_enhance=2, + t_enhance=t_enhance, + n_batches=2, + feature_sets={ + 'lr_features': FEATURES_W, + 'hr_exo_features': [f'{feat}_obs' for feat in FEATURES_W], + 'hr_out_features': [f'{feat}_obs' for feat in FEATURES_W], + }, + mode='lazy', + ) + + for batch in batch_handler: + assert not np.isnan(batch.high_res).all() + assert np.isnan(batch.high_res).any() + + Sup3rGan.seed() + model = Sup3rGan( + gen_config, + fp_disc, + learning_rate=1e-4, + loss={ + 'GeothermalPhysicsLossWithObs': { + 'gen_features': [f'{feat}_obs' for feat in FEATURES_W], + 'true_features': [f'{feat}_obs' for feat in FEATURES_W], + } + }, + ) + + with tempfile.TemporaryDirectory() as 
td: + model_kwargs = { + 'input_resolution': {'spatial': '30km', 'temporal': '60min'}, + 'n_epoch': 5, + 'weight_gen_advers': 0.0, + 'train_gen': True, + 'train_disc': False, + 'checkpoint_int': 1, + 'out_dir': os.path.join(td, 'test_{epoch}'), + } + + model.train(batch_handler, **model_kwargs) + + tloss = model.history['train_geothermal_physics_loss_with_obs'].values + assert not np.isnan(tloss).any() + + +@pytest.mark.parametrize('lr_only_features', [[], ['temperature_2m']]) +def test_train_proxy_obs_with_topo(lr_only_features, request): + """Test training with topo and obs. Make sure exo features are + properly concatenated.""" + + gen_config = 'gen_config_with_obs_3d_topo' + gen_config = request.getfixturevalue(gen_config)() + kwargs = { + 'file_paths': pytest.FP_WTK, + 'features': [*FEATURES_W, 'topography'], + 'target': TARGET_W, + 'shape': SHAPE, + } + + train_handler = DataHandler(**kwargs, time_slice=slice(None, 3000, 10)) + val_handler = DataHandler(**kwargs, time_slice=slice(3000, None, 10)) + + # Add dummy lr only features + if lr_only_features: + for feat in lr_only_features: + train_handler[feat] = train_handler[FEATURES_W[0]].copy() + val_handler[feat] = val_handler[FEATURES_W[0]].copy() + + batcher = BatchHandler( + [train_handler], + [val_handler], + batch_size=2, + n_batches=1, + s_enhance=2, + t_enhance=2, + sample_shape=(20, 20, 10), + proxy_obs_kwargs={'onshore_obs_frac': {'spatial': 0.1}}, + feature_sets={ + 'lr_features': [*lr_only_features, *FEATURES_W], + 'hr_exo_features': [ + 'topography', + *[f'{feat}_obs' for feat in FEATURES_W], + ], + 'hr_out_features': FEATURES_W, + }, + ) + + Sup3rGan.seed() + + model = Sup3rGan( + gen_config, + pytest.ST_FP_DISC, + learning_rate=1e-4, + loss={ + 'GeothermalPhysicsLossWithObs': { + 'gen_features': FEATURES_W, + 'true_features': [f'{feat}_obs' for feat in FEATURES_W], + } + }, + ) + with tempfile.TemporaryDirectory() as td: + model_kwargs = { + 'input_resolution': {'spatial': '16km', 'temporal': 
'3600min'}, + 'n_epoch': 3, + 'weight_gen_advers': 0.0, + 'train_gen': True, + 'train_disc': False, + 'checkpoint_int': None, + 'out_dir': os.path.join(td, 'test_{epoch}'), + } + + model.train(batcher, **model_kwargs) + + loss = model.history['train_geothermal_physics_loss_with_obs'].values + assert not np.isnan(loss).any() + gloss = model.history['train_loss_gen'].values + assert not np.isnan(gloss).any() + + +@pytest.mark.parametrize('lr_only_features', [[], ['temperature_2m']]) +def test_train_real_obs_with_topo(lr_only_features, request): + """Test model training with only sparse high resolution ground truth data. + This should skip any calculations involving the discriminator - since + train_disc=False""" + + gen_config = request.getfixturevalue('gen_config_with_obs_3d_topo')() + kwargs = { + 'features': [*FEATURES_W, 'topography'], + 'target': TARGET_W, + 'shape': SHAPE, + } + + hr_handler = DataHandler( + pytest.FP_WTK, + **kwargs, + time_slice=slice(None, None, 1), + ) + + lr_handler = DataHandler( + pytest.FP_ERA, + features=FEATURES_W, + time_slice=slice(None, None, 2), + ) + + # Add dummy lr only features + if lr_only_features: + for feat in lr_only_features: + lr_handler[feat] = lr_handler[FEATURES_W[0]].copy() + + dual_rasterizer = DualRasterizer( + data={'low_res': lr_handler.data, 'high_res': hr_handler.data}, + s_enhance=2, + t_enhance=2, + run_qa=False, + ) + obs_data = dual_rasterizer.high_res.copy() + for feat in FEATURES_W: + tmp = np.full(obs_data[feat].shape, np.nan) + lat_ids = RANDOM_GENERATOR.choice( + obs_data[feat].shape[0], size=5, replace=False + ) + lon_ids = RANDOM_GENERATOR.choice( + obs_data[feat].shape[1], size=5, replace=False + ) + for ilat, ilon in itertools.product(lat_ids, lon_ids): + tmp[ilat, ilon, :] = obs_data[feat][ilat, ilon] + obs_data[f'{feat}_obs'] = (obs_data[feat].dims, tmp) + + dual_with_obs = Container( + data={ + 'low_res': dual_rasterizer.low_res, + 'high_res': obs_data, + } + ) + + batch_handler = 
DualBatchHandlerWithObsTester( + train_containers=[dual_with_obs], + val_containers=[], + sample_shape=(20, 20, 10), + batch_size=2, + s_enhance=2, + t_enhance=2, + n_batches=1, + feature_sets={ + 'lr_features': [*lr_only_features, *FEATURES_W], + 'hr_exo_features': [ + 'topography', + *[f'{feat}_obs' for feat in FEATURES_W], + ], + 'hr_out_features': FEATURES_W, + }, + mode='lazy', + ) + + Sup3rGan.seed() + model = Sup3rGan( + gen_config, + pytest.ST_FP_DISC, + learning_rate=1e-4, + loss={ + 'MeanAbsoluteError': {}, + 'GeothermalPhysicsLossWithObs': { + 'gen_features': FEATURES_W, + 'true_features': [f'{feat}_obs' for feat in FEATURES_W], + } + }, + ) + + with tempfile.TemporaryDirectory() as td: + model_kwargs = { + 'input_resolution': {'spatial': '30km', 'temporal': '60min'}, + 'n_epoch': 5, + 'weight_gen_advers': 0.0, + 'train_gen': True, + 'train_disc': False, + 'checkpoint_int': 1, + 'out_dir': os.path.join(td, 'test_{epoch}'), + } + + model.train(batch_handler, **model_kwargs) + + tloss = model.history['train_loss_gen'].values + assert not np.isnan(tloss).any() + + model_kwargs['train_disc'] = True + model.train(batch_handler, **model_kwargs) + + tloss_disc = model.history['train_loss_disc'].values + assert not np.isnan(tloss_disc).any() diff --git a/tests/utilities/test_loss_metrics.py b/tests/utilities/test_loss_metrics.py index b32e35b441..9669d58970 100644 --- a/tests/utilities/test_loss_metrics.py +++ b/tests/utilities/test_loss_metrics.py @@ -271,7 +271,7 @@ def test_md_loss(): y = x.copy() md_loss = MaterialDerivativeLoss( - input_features=['u_100m', 'v_100m'] + gen_features=['u_100m', 'v_100m'] ) u_div = md_loss._compute_md(x, feature='u_100m') v_div = md_loss._compute_md(x, feature='v_100m') @@ -301,16 +301,22 @@ def test_multiterm_loss(): y = x.copy() md_loss = MaterialDerivativeLoss( - input_features=['u_100m', 'v_100m', 'temp_100m'] + gen_features=['u_100m', 'v_100m', 'temp_100m'] ) mae_loss = MeanAbsoluteError() fp_gen = os.path.join(CONFIG_DIR, 
'spatial/gen_2x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - model.meta['hr_out_features'] = ['u_100m', 'v_100m', 'temp_100m'] + model.set_model_params( + lr_features=['u_100m', 'v_100m', 'temp_100m'], + hr_out_features=['u_100m', 'v_100m', 'temp_100m'], + input_resolution={'spatial': '12km', 'temporal': '60min'}, + s_enhance=1, + t_enhance=1, + ) multi_loss = model.get_loss_fun({ 'MaterialDerivativeLoss': { - 'input_features': ['u_100m', 'v_100m', 'temp_100m'] + 'gen_features': ['u_100m', 'v_100m', 'temp_100m'] }, 'MeanAbsoluteError': {}, 'term_weights': [0.2, 0.8],