Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
f056f4c
change base loss name - not always used for physics loss
bnb32 Mar 6, 2026
cf34689
Extend obs model to enable training on just sparse high res data
bnb32 Mar 9, 2026
a80d2eb
Remove TODO comments regarding observation handling in Sup3rGanWithObs
bnb32 Mar 9, 2026
b264af1
refact: move proxy obs sampling logic to sampler objects. removed Sup…
bnb32 Mar 12, 2026
bba5e91
`sparse_disc` arg added to Sup3rGan - True if discriminator can handl…
bnb32 Mar 12, 2026
129b99b
test fixes
bnb32 Mar 13, 2026
8966ee6
test fixes
bnb32 Mar 13, 2026
65a681d
test fixes
bnb32 Mar 13, 2026
0e87c25
refact: remove unused batch handler attribute checks and improve feat…
bnb32 Mar 14, 2026
6a5d458
test: add tests for sampler behavior with proxy observations
bnb32 Mar 14, 2026
03e7493
make sure single sampler and dual sampler index obs and hr features c…
bnb32 Mar 14, 2026
473a95d
fix: hr_feature ordering with exo features
bnb32 Mar 14, 2026
19ca941
Refactor model parameter setting and update feature handling
bnb32 Mar 15, 2026
bbdcb0a
fix: update enhancement factor calculations and model initialization …
bnb32 Mar 15, 2026
e15ae1b
fix: update model parameter settings for multi-step GAN tests
bnb32 Mar 15, 2026
17d568c
fix: add s_enhance and t_enhance parameters to multiterm loss test
bnb32 Mar 15, 2026
76484e8
small doc string clarification
bnb32 Mar 18, 2026
b248a82
doc string edit and notebook typo removed
bnb32 Mar 18, 2026
04208ad
log message about exo rasterizer selection
bnb32 Mar 18, 2026
bd9a8b9
fix: typo
bnb32 Mar 18, 2026
012b847
Merge branch 'main' into bnb/obs_model
bnb32 Mar 19, 2026
93791d0
*removed sparse_disc argument - use train_disc arg to skip disc calcs…
bnb32 Mar 19, 2026
b84ab4f
refactor: update daily_reduction checks to use set for improved perfo…
bnb32 Mar 20, 2026
65efd32
fix: still need `sparse_disc` flag for removing obs features from dis…
bnb32 Mar 21, 2026
16474dd
fix: previous logic was overwriting layer.features and including laye…
bnb32 Mar 21, 2026
ac47940
fix: lock on gradient - without it multi gpu was failing suddenly.
bnb32 Mar 22, 2026
514b7c9
fix: update loss assertions to check for NaN values and adjust traini…
bnb32 Mar 22, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ xarray = ">=2023.0"

[tool.pixi.pypi-dependencies]
NREL-sup3r = { path = ".", editable = true }
NREL-rex = { version = ">=0.2.87" }
NREL-rex = { version = ">=0.2.91" }
NREL-phygnn = { version = ">=0.0.23" }
NREL-gaps = { version = ">=0.6.13" }
NREL-farms = { version = ">=1.0.4" }
Expand Down
10 changes: 5 additions & 5 deletions sup3r/bias/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ def _reduce_base_data(
)

cs_ratio = (
daily_reduction.lower() in ('avg', 'average', 'mean')
daily_reduction.lower() in {'avg', 'average', 'mean'}
and base_dset == 'clearsky_ratio'
)

Expand All @@ -756,16 +756,16 @@ def _reduce_base_data(
)
assert not np.isnan(base_data).any(), msg

elif daily_reduction.lower() in ('avg', 'average', 'mean'):
elif daily_reduction.lower() in {'avg', 'average', 'mean'}:
base_data = df.groupby('date').mean()['base_data'].values

elif daily_reduction.lower() in ('max', 'maximum'):
elif daily_reduction.lower() in {'max', 'maximum'}:
base_data = df.groupby('date').max()['base_data'].values

elif daily_reduction.lower() in ('min', 'minimum'):
elif daily_reduction.lower() in {'min', 'minimum'}:
base_data = df.groupby('date').min()['base_data'].values

elif daily_reduction.lower() in ('sum', 'total'):
elif daily_reduction.lower() in {'sum', 'total'}:
base_data = df.groupby('date').sum()['base_data'].values

msg = (
Expand Down
7 changes: 2 additions & 5 deletions sup3r/bias/bias_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,8 @@ def local_qdm_bc(
--------
sup3r.bias.qdm.QuantileDeltaMappingCorrection :
Estimate probability distributions required by QDM method
rex.utilities.bc_utils.QuantileDeltaMapping :
Core QDM transformation.

Notes
-----
Expand All @@ -738,11 +740,6 @@ def local_qdm_bc(
Also, :class:`rex.utilities.bc_utils.QuantileDeltaMapping` expects params
to be 2D (space, N-params).

See Also
--------
rex.utilities.bc_utils.QuantileDeltaMapping :
Core QDM transformation.

Examples
--------
>>> unbiased = local_qdm_bc(biased_array, lat_lon_array, "ghi", "rsds",
Expand Down
1 change: 0 additions & 1 deletion sup3r/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,5 @@
from .multi_step import MultiStepGan, MultiStepSurfaceMetGan, SolarMultiStepGan
from .solar_cc import SolarCC
from .surface import SurfaceSpatialMetModel
from .with_obs import Sup3rGanWithObs

SPATIAL_FIRST_MODELS = (MultiStepSurfaceMetGan, SolarMultiStepGan)
151 changes: 92 additions & 59 deletions sup3r/models/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,23 @@
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from inspect import signature
from threading import Lock
from warnings import warn

import numpy as np
import pandas as pd
import tensorflow as tf
from gaps.config import load_config
from phygnn import CustomNetwork
from tensorflow.keras import optimizers

import sup3r.utilities.loss_metrics
from phygnn import CustomNetwork
from sup3r.preprocessing.data_handlers import ExoData
from sup3r.preprocessing.utilities import numpy_if_tensor
from sup3r.utilities import VERSION_RECORD
from sup3r.utilities.utilities import Timer, camel_to_underscore, safe_cast

from .utilities import SUP3R_LAYERS, SUP3R_OBS_LAYERS, TensorboardMixIn
from .utilities import SUP3R_LAYERS, TensorboardMixIn

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -51,6 +52,7 @@ def __init__(self):
self._stdevs = None
self._train_record = pd.DataFrame()
self._val_record = pd.DataFrame()
self._lock = Lock()

def load_network(self, model, name):
"""Load a CustomNetwork object from hidden layers config, .json file
Expand Down Expand Up @@ -105,10 +107,9 @@ def _load_model_from_string(self, model, name):
return model['meta'][f'config_{name}']['hidden_layers']

msg = (
'Could not load model from json config, need '
'"hidden_layers" key or '
f'"meta/config_{name}/hidden_layers" '
' at top level but only found: {}'.format(model.keys())
'Could not load model from json config, need "hidden_layers" key '
f'or "meta/config_{name}/hidden_layers" at top level but only '
f'found: {model.keys()}'
)
logger.error(msg)
raise KeyError(msg)
Expand Down Expand Up @@ -460,33 +461,44 @@ def _combine_loss_input(self, hi_res_true, hi_res_gen):
hi_res_gen = tf.concat((hi_res_gen, *exo_data), axis=-1)
return hi_res_gen

def _get_loss_inputs(self, hi_res_true, hi_res_gen, loss_func):
def _get_loss_inputs(self, hi_res_gen, hi_res_true, loss_func):
"""Get inputs for the given loss function according to the required
input features"""
generator output features and ground truth features. If the loss
function doesn't specify required features, this will default to using
all output features that are not exogenous features."""

msg = (
f'{loss_func} requires input features: '
f'{loss_func.input_features}, but these are not found '
f'in the model output features: {self.hr_out_features}'
)
if not all(
f in self.hr_out_features for f in loss_func.input_features
):
gen_feats = getattr(loss_func, 'gen_features', 'all')
true_feats = getattr(loss_func, 'true_features', 'all')

if gen_feats == 'all':
gen_feats = self.hr_out_features
if true_feats == 'all':
true_feats = self.hr_out_features

if not all(f in self.hr_out_features for f in gen_feats):
msg = (
f'{loss_func} requires gen_features: '
f'{loss_func.gen_features}, but these are not found in the '
f'high-resolution output features: {self.hr_out_features}'
)
logger.error(msg)
raise ValueError(msg)

input_inds = [
self.hr_out_features.index(f) for f in loss_func.input_features
]
hr_true = tf.stack(
[hi_res_true[..., idx] for idx in input_inds],
axis=-1,
)
hr_gen = tf.stack(
[hi_res_gen[..., idx] for idx in input_inds],
axis=-1,
)
return hr_true, hr_gen
if not all(f in self.hr_features for f in true_feats):
msg = (
f'{loss_func} requires true_features: '
f'{loss_func.true_features}, but these are not found '
f'in the high-resolution features: {self.hr_features}'
)
logger.error(msg)
raise ValueError(msg)

gen_inds = [self.hr_features.index(f) for f in gen_feats]
true_inds = [self.hr_features.index(f) for f in true_feats]

hr_true = tf.gather(hi_res_true, true_inds, axis=-1)
hr_gen = tf.gather(hi_res_gen, gen_inds, axis=-1)
return hr_gen, hr_true

def get_loss_fun(self, loss):
"""Get full, possibly multi-term, loss function from the provided str
Expand Down Expand Up @@ -519,21 +531,24 @@ def get_loss_fun(self, loss):
loss_funcs = [self._get_loss_fun({ln: loss[ln]}) for ln in lns]
weights = copy.deepcopy(loss).pop('term_weights', [1.0] * len(lns))

def loss_fun(hi_res_true, hi_res_gen):
def loss_fun(hi_res_gen, hi_res_true):
loss_details = {}
loss = 0
for i, (ln, loss_func) in enumerate(zip(lns, loss_funcs)):
if (
not hasattr(loss_func, 'input_features')
or loss_func.input_features == 'all'
):
val = loss_func(hi_res_true, hi_res_gen)
else:
hr_true, hr_gen = self._get_loss_inputs(
hi_res_true, hi_res_gen, loss_func
)
val = loss_func(hr_true, hr_gen)
hr_gen, hr_true = self._get_loss_inputs(
hi_res_gen, hi_res_true, loss_func
)
val = loss_func(hr_gen, hr_true)
loss_details[camel_to_underscore(ln)] = val
tf.debugging.assert_all_finite(
val,
message=(
f'NaN or Inf values found for loss term "{ln}" with '
f'value {val} when running loss function {loss_func} '
f'generated tensor of shape {hi_res_gen.shape} and '
f'true tensor of shape {hi_res_true.shape}'
),
)
loss += weights[i] * val
return loss, loss_details

Expand Down Expand Up @@ -1033,24 +1048,39 @@ def run_exo_layer(self, layer, input_array, exogenous_data, norm_in=True):
of shape:
(n_obs, spatial_1, spatial_2, n_features)
(n_obs, spatial_1, spatial_2, n_temporal, n_features)
exogenous_data : dict | ExoData
Special dictionary (class:`ExoData`) of exogenous feature data with
entries describing whether features should be combined at input, a
mid network layer, or with output. This doesn't have to include the
'model' key since this data is for a single step model.
norm_in : bool
Flag to normalize low_res input data if the self._means,
self._stdevs attributes are available. The generator should always
received normalized data with mean=0 stdev=1. This also normalizes
exogenous data.
"""
feat_stack = []
extras = []
features = getattr(layer, 'features', [layer.name])
exo_features = getattr(layer, 'exo_features', [])
is_obs_layer = isinstance(layer, SUP3R_OBS_LAYERS)
for feat in features + exo_features:
missing_obs = feat in features and feat not in exogenous_data
if is_obs_layer and missing_obs:
missing_feat = feat not in exogenous_data
if missing_feat and '_obs' in feat:
msg = (
f'{feat} does not match any features in exogenous_data '
f'({list(exogenous_data)}). Will run without this '
'observation feature.'
f'({list(exogenous_data)}). Will try to run without this '
'feature.'
)
logger.warning(msg)
continue
msg = f'exogenous_data is missing required feature "{feat}"'
assert feat in exogenous_data, msg
if missing_feat:
msg = (
f'{feat} does not match any features in exogenous_data '
f'({list(exogenous_data)}). This feature is required for '
f'layer {layer.name}.'
)
logger.error(msg)
raise KeyError(msg)
exo = exogenous_data.get_combine_type_data(feat, 'layer')
exo = self._reshape_norm_exo(
input_array,
Expand Down Expand Up @@ -1088,7 +1118,7 @@ def generate(
Flag to normalize low_res input data if the self._means,
self._stdevs attributes are available. The generator should always
received normalized data with mean=0 stdev=1. This also normalizes
hi_res_topo.
exogenous data.
un_norm_out : bool
Flag to un-normalize synthetically generated output data to physical
units
Expand Down Expand Up @@ -1203,8 +1233,9 @@ def _tf_generate(self, low_res, hi_res_exo=None):
else:
hi_res = layer(hi_res)
except Exception as e:
msg = 'Could not run layer #{} "{}" on tensor of shape {}'.format(
layer_num, layer, hi_res.shape
msg = (
f'Could not run layer #{layer_num} "{layer}" on tensor '
f'of shape {hi_res.shape}'
)
logger.error(msg)
raise RuntimeError(msg) from e
Expand All @@ -1218,7 +1249,8 @@ def _get_hr_exo_and_loss(
**calc_loss_kwargs,
):
"""Get high-resolution exogenous data, generate synthetic output, and
compute loss."""
compute loss. All hr_exo_features are extracted from hi_res_true and
added to exo_data."""
hi_res_exo = self.get_hr_exo_input(hi_res_true)
hi_res_gen = self._tf_generate(low_res, hi_res_exo)
loss, loss_details = self.calc_loss(
Expand Down Expand Up @@ -1266,15 +1298,16 @@ def get_single_grad(
loss_details : dict
Namespace of the breakdown of loss components
"""
with (
tf.device(device_name),
tf.GradientTape(watch_accessed_variables=False) as tape,
):
tape.watch(training_weights)
loss, loss_details, _, _ = self._get_hr_exo_and_loss(
low_res, hi_res_true, **calc_loss_kwargs
)
grad = tape.gradient(loss, training_weights)
with self._lock:
with (
tf.device(device_name),
tf.GradientTape(watch_accessed_variables=False) as tape,
):
tape.watch(training_weights)
loss, loss_details, _, _ = self._get_hr_exo_and_loss(
low_res, hi_res_true, **calc_loss_kwargs
)
grad = tape.gradient(loss, training_weights)
return grad, loss_details

@abstractmethod
Expand Down
Loading
Loading