Skip to content
Open
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
f056f4c
change base loss name - not always used for physics loss
bnb32 Mar 6, 2026
cf34689
Extend obs model to enable training on just sparse high res data
bnb32 Mar 9, 2026
a80d2eb
Remove TODO comments regarding observation handling in Sup3rGanWithObs
bnb32 Mar 9, 2026
b264af1
refact: move proxy obs sampling logic to sampler objects. removed Sup…
bnb32 Mar 12, 2026
bba5e91
`sparse_disc` arg added to Sup3rGan - True if discriminator can handl…
bnb32 Mar 12, 2026
129b99b
test fixes
bnb32 Mar 13, 2026
8966ee6
test fixes
bnb32 Mar 13, 2026
65a681d
test fixes
bnb32 Mar 13, 2026
0e87c25
refact: remove unused batch handler attribute checks and improve feat…
bnb32 Mar 14, 2026
6a5d458
test: add tests for sampler behavior with proxy observations
bnb32 Mar 14, 2026
03e7493
make sure single sampler and dual sampler index obs and hr features c…
bnb32 Mar 14, 2026
473a95d
fix: hr_feature ordering with exo features
bnb32 Mar 14, 2026
19ca941
Refactor model parameter setting and update feature handling
bnb32 Mar 15, 2026
bbdcb0a
fix: update enhancement factor calculations and model initialization …
bnb32 Mar 15, 2026
e15ae1b
fix: update model parameter settings for multi-step GAN tests
bnb32 Mar 15, 2026
17d568c
fix: add s_enhance and t_enhance parameters to multiterm loss test
bnb32 Mar 15, 2026
76484e8
small doc string clarification
bnb32 Mar 18, 2026
b248a82
doc string edit and notebook typo removed
bnb32 Mar 18, 2026
04208ad
log message about exo rasterizer selection
bnb32 Mar 18, 2026
bd9a8b9
fix: typo
bnb32 Mar 18, 2026
012b847
Merge branch 'main' into bnb/obs_model
bnb32 Mar 19, 2026
93791d0
*removed sparse_disc argument - use train_disc arg to skip disc calcs…
bnb32 Mar 19, 2026
b84ab4f
refactor: update daily_reduction checks to use set for improved perfo…
bnb32 Mar 20, 2026
65efd32
fix: still need `sparse_disc` flag for removing obs features from dis…
bnb32 Mar 21, 2026
16474dd
fix: previous logic was overwriting layer.features and including laye…
bnb32 Mar 21, 2026
ac47940
fix: lock on gradient - without it multi gpu was failing suddenly.
bnb32 Mar 22, 2026
514b7c9
fix: update loss assertions to check for NaN values and adjust traini…
bnb32 Mar 22, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ xarray = ">=2023.0"

[tool.pixi.pypi-dependencies]
NREL-sup3r = { path = ".", editable = true }
NREL-rex = { version = ">=0.2.87" }
NREL-rex = { version = ">=0.2.91" }
NREL-phygnn = { version = ">=0.0.23" }
NREL-gaps = { version = ">=0.6.13" }
NREL-farms = { version = ">=1.0.4" }
Expand Down
1 change: 0 additions & 1 deletion sup3r/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,5 @@
from .multi_step import MultiStepGan, MultiStepSurfaceMetGan, SolarMultiStepGan
from .solar_cc import SolarCC
from .surface import SurfaceSpatialMetModel
from .with_obs import Sup3rGanWithObs

SPATIAL_FIRST_MODELS = (MultiStepSurfaceMetGan, SolarMultiStepGan)
111 changes: 63 additions & 48 deletions sup3r/models/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from sup3r.utilities import VERSION_RECORD
from sup3r.utilities.utilities import Timer, camel_to_underscore, safe_cast

from .utilities import SUP3R_LAYERS, SUP3R_OBS_LAYERS, TensorboardMixIn
from .utilities import SUP3R_LAYERS, TensorboardMixIn

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -460,33 +460,46 @@ def _combine_loss_input(self, hi_res_true, hi_res_gen):
hi_res_gen = tf.concat((hi_res_gen, *exo_data), axis=-1)
return hi_res_gen

def _get_loss_inputs(self, hi_res_true, hi_res_gen, loss_func):
def _get_loss_inputs(self, hi_res_gen, hi_res_true, loss_func):
"""Get inputs for the given loss function according to the required
input features"""

msg = (
f'{loss_func} requires input features: '
f'{loss_func.input_features}, but these are not found '
f'in the model output features: {self.hr_out_features}'
)
if not all(
f in self.hr_out_features for f in loss_func.input_features
gen_feats = getattr(loss_func, 'gen_features', 'all')
true_feats = getattr(loss_func, 'true_features', 'all')

if gen_feats != 'all' and not all(
f in self.hr_features for f in gen_feats
):
msg = (
f'{loss_func} requires gen_features: '
f'{loss_func.gen_features}, but these are not found '
f'in the high-resolution features: {self.hr_features}'
)
logger.error(msg)
raise ValueError(msg)

input_inds = [
self.hr_out_features.index(f) for f in loss_func.input_features
]
hr_true = tf.stack(
[hi_res_true[..., idx] for idx in input_inds],
axis=-1,
)
hr_gen = tf.stack(
[hi_res_gen[..., idx] for idx in input_inds],
axis=-1,
)
return hr_true, hr_gen
if true_feats != 'all' and not all(
f in self.hr_features for f in true_feats
):
msg = (
f'{loss_func} requires true_features: '
f'{loss_func.true_features}, but these are not found '
f'in the high-resolution output features: {self.hr_features}'
)
logger.error(msg)
raise ValueError(msg)

if gen_feats == 'all':
gen_feats = self.hr_features
if true_feats == 'all':
true_feats = self.hr_features

gen_inds = [self.hr_features.index(f) for f in gen_feats]
true_inds = [self.hr_features.index(f) for f in true_feats]

hr_true = tf.gather(hi_res_true, true_inds, axis=-1)
hr_gen = tf.gather(hi_res_gen, gen_inds, axis=-1)
return hr_gen, hr_true

def get_loss_fun(self, loss):
"""Get full, possibly multi-term, loss function from the provided str
Expand Down Expand Up @@ -519,20 +532,14 @@ def get_loss_fun(self, loss):
loss_funcs = [self._get_loss_fun({ln: loss[ln]}) for ln in lns]
weights = copy.deepcopy(loss).pop('term_weights', [1.0] * len(lns))

def loss_fun(hi_res_true, hi_res_gen):
def loss_fun(hi_res_gen, hi_res_true):
loss_details = {}
loss = 0
for i, (ln, loss_func) in enumerate(zip(lns, loss_funcs)):
if (
not hasattr(loss_func, 'input_features')
or loss_func.input_features == 'all'
):
val = loss_func(hi_res_true, hi_res_gen)
else:
hr_true, hr_gen = self._get_loss_inputs(
hi_res_true, hi_res_gen, loss_func
)
val = loss_func(hr_true, hr_gen)
hr_gen, hr_true = self._get_loss_inputs(
hi_res_gen, hi_res_true, loss_func
)
val = loss_func(hr_gen, hr_true)
loss_details[camel_to_underscore(ln)] = val
loss += weights[i] * val
return loss, loss_details
Expand Down Expand Up @@ -1038,19 +1045,24 @@ def run_exo_layer(self, layer, input_array, exogenous_data, norm_in=True):
extras = []
features = getattr(layer, 'features', [layer.name])
exo_features = getattr(layer, 'exo_features', [])
is_obs_layer = isinstance(layer, SUP3R_OBS_LAYERS)
for feat in features + exo_features:
missing_obs = feat in features and feat not in exogenous_data
if is_obs_layer and missing_obs:
missing_feat = feat not in exogenous_data
if missing_feat and '_obs' in feat:
msg = (
f'{feat} does not match any features in exogenous_data '
f'({list(exogenous_data)}). Will run without this '
'observation feature.'
f'({list(exogenous_data)}). Will try to run without this '
'feature.'
)
logger.warning(msg)
continue
msg = f'exogenous_data is missing required feature "{feat}"'
assert feat in exogenous_data, msg
elif missing_feat:
msg = (
f'{feat} does not match any features in exogenous_data '
f'({list(exogenous_data)}). This feature is required for '
f'layer {layer.name}.'
)
logger.error(msg)
raise KeyError(msg)
exo = exogenous_data.get_combine_type_data(feat, 'layer')
exo = self._reshape_norm_exo(
input_array,
Expand Down Expand Up @@ -1195,19 +1207,20 @@ def _tf_generate(self, low_res, hi_res_exo=None):
"""
hi_res = self.generator.layers[0](low_res)
layer_num = 1
try:
for i, layer in enumerate(self.generator.layers[1:]):
for i, layer in enumerate(self.generator.layers[1:]):
try:
layer_num = i + 1
if isinstance(layer, SUP3R_LAYERS):
hi_res = self._run_exo_layer(layer, hi_res, hi_res_exo)
else:
hi_res = layer(hi_res)
except Exception as e:
msg = 'Could not run layer #{} "{}" on tensor of shape {}'.format(
layer_num, layer, hi_res.shape
)
logger.error(msg)
raise RuntimeError(msg) from e
except Exception as e:
msg = (
f'Could not run layer #{layer_num} "{layer}" on tensor '
f'of shape {hi_res.shape}'
)
logger.error(msg)
raise RuntimeError(msg) from e

return hi_res

Expand All @@ -1218,7 +1231,9 @@ def _get_hr_exo_and_loss(
**calc_loss_kwargs,
):
"""Get high-resolution exogenous data, generate synthetic output, and
compute loss."""
compute loss. Obs features (if present at the end of hi_res_true) are
extracted and added to exo_data, and trimmed from hi_res_true before
loss calculation."""
hi_res_exo = self.get_hr_exo_input(hi_res_true)
hi_res_gen = self._tf_generate(low_res, hi_res_exo)
loss, loss_details = self.calc_loss(
Expand Down
75 changes: 29 additions & 46 deletions sup3r/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def _tf_discriminate(self, hi_res):
out : np.ndarray
Discriminator output logits
"""

out = self.discriminator.layers[0](hi_res)
layer_num = 1
try:
Expand Down Expand Up @@ -392,7 +393,7 @@ def weights(self):
"""
return self.generator_weights + self.discriminator_weights

def init_weights(self, lr_shape, hr_shape, device=None):
def init_weights(self, lr_shape, hr_shape, train_disc=False, device=None):
"""Initialize the generator and discriminator weights with device
placement.

Expand All @@ -406,12 +407,15 @@ def init_weights(self, lr_shape, hr_shape, device=None):
Shape of one batch of high res input data for sup3r resolution.
Note that the batch size (axis=0) must be included, but the actual
batch size doesn't really matter.
train_disc : bool
Whether to initialize the discriminator weights. If False, only the
generator weights will be initialized.
device : str | None
Option to place model weights on a device. If None,
self.default_device will be used.
"""

if not self.generator_weights:
if not self.generator_weights or not self.discriminator_weights:
if device is None:
device = self.default_device

Expand All @@ -426,7 +430,7 @@ def init_weights(self, lr_shape, hr_shape, device=None):

with tf.device(device):
hr_exo_data = {}
for feature in self.hr_exo_features + self.obs_features:
for feature in self.hr_exo_features:
hr_exo_data[feature] = hr_exo
out = self._tf_generate(low_res, hr_exo_data)
msg = (
Expand All @@ -435,7 +439,9 @@ def init_weights(self, lr_shape, hr_shape, device=None):
f'{len(self.hr_out_features)}'
)
assert out.shape[-1] == len(self.hr_out_features), msg
_ = self._tf_discriminate(hi_res)

if train_disc:
_ = self._tf_discriminate(hi_res)

@staticmethod
def get_weight_update_fraction(
Expand Down Expand Up @@ -494,14 +500,7 @@ def calc_loss_gen_content(self, hi_res_true, hi_res_gen):
0D tensor generator model loss for the content loss comparing the
hi res ground truth to the hi res synthetically generated output.
"""
slc = (
slice(0, None)
if len(self.hr_exo_features) == 0
else slice(0, -len(self.hr_exo_features))
)
# gen is first since loss can include regularizers which just
# apply to generator output
return self.loss_fun(hi_res_gen[..., slc], hi_res_true[..., slc])
return self.loss_fun(hi_res_gen, hi_res_true)

@staticmethod
@tf.function
Expand Down Expand Up @@ -606,22 +605,6 @@ def update_adversarial_weights(

return weight_gen_advers

@staticmethod
def check_batch_handler_attrs(batch_handler):
"""Not all batch handlers have the following attributes. So we perform
some sanitation before sending to `set_model_params`"""
return {
k: getattr(batch_handler, k, None)
for k in [
'smoothing',
'lr_features',
'hr_exo_features',
'hr_out_features',
'smoothed_features',
]
if hasattr(batch_handler, k)
}

def train(
self,
batch_handler,
Expand Down Expand Up @@ -725,12 +708,8 @@ def train(
self._init_tensorboard_writer(out_dir)

self.set_norm_stats(batch_handler.means, batch_handler.stds)
params = self.check_batch_handler_attrs(batch_handler)
self.set_model_params(
input_resolution=input_resolution,
s_enhance=batch_handler.s_enhance,
t_enhance=batch_handler.t_enhance,
**params,
input_resolution=input_resolution, batch_handler=batch_handler
)

epochs = list(range(n_epoch))
Expand Down Expand Up @@ -881,28 +860,36 @@ def calc_loss(
logger.error(msg)
raise RuntimeError(msg)

disc_out_true = self._tf_discriminate(hi_res_true)
disc_out_gen = self._tf_discriminate(hi_res_gen)

loss_details = {}
loss = None
disc_out_true = None
disc_out_gen = None
loss_gen_advers = None

if compute_disc or train_disc:
if train_disc or compute_disc:
disc_out_true = self._tf_discriminate(hi_res_true)
disc_out_gen = self._tf_discriminate(hi_res_gen)
loss_details['loss_disc'] = self.calc_loss_disc(
disc_out_true=disc_out_true, disc_out_gen=disc_out_gen
)

if train_gen and compute_disc:
loss_gen_advers = self.calc_loss_disc(
disc_out_true=disc_out_gen, disc_out_gen=disc_out_true
)
loss_details['loss_gen_advers'] = loss_gen_advers

if train_gen:
loss_gen_content, loss_gen_content_details = (
self.calc_loss_gen_content(hi_res_true, hi_res_gen)
)
loss_gen_advers = self.calc_loss_disc(
disc_out_true=disc_out_gen, disc_out_gen=disc_out_true
loss = (
loss_gen_content
if loss_gen_advers is None
else weight_gen_advers * loss_gen_advers
)
loss = loss_gen_content + weight_gen_advers * loss_gen_advers
loss_details['loss_gen'] = loss
loss_details['loss_gen_content'] = loss_gen_content
loss_details['loss_gen_advers'] = loss_gen_advers
loss_details.update(loss_gen_content_details)

elif train_disc:
Expand Down Expand Up @@ -1140,11 +1127,7 @@ def _train_epoch(
Namespace of the breakdown of loss components
"""
lr_shape, hr_shape = batch_handler.shapes
self.init_weights(lr_shape, hr_shape)

self.init_weights(
(1, *batch_handler.lr_shape), (1, *batch_handler.hr_shape)
)
self.init_weights(lr_shape, hr_shape, train_disc=train_disc)

disc_th_low = np.min(disc_loss_bounds)
disc_th_high = np.max(disc_loss_bounds)
Expand Down
Loading
Loading