181 changes: 122 additions & 59 deletions pixi.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
@@ -41,7 +41,8 @@ dependencies = [
"pytest>=5.2",
"scipy>=1.0.0",
"sphinx>=7.0",
"xarray>=2023.0"
"xarray>=2024.0",
"zarr>=2.0.0,<4",
]

# If used, causes glibc conflict
@@ -298,7 +299,7 @@ matplotlib = ">=3.1"
numpy = "~=1.7"
pandas = ">=2.0"
scipy = ">=1.0.0"
xarray = ">=2023.0"
xarray = ">=2024.0"

[tool.pixi.pypi-dependencies]
NREL-sup3r = { path = ".", editable = true }
1 change: 1 addition & 0 deletions sup3r/models/__init__.py
@@ -7,6 +7,7 @@
from .multi_step import MultiStepGan, MultiStepSurfaceMetGan, SolarMultiStepGan
from .solar_cc import SolarCC
from .surface import SurfaceSpatialMetModel
from .training_config import TrainingConfig
from .with_obs import Sup3rGanWithObs

SPATIAL_FIRST_MODELS = (MultiStepSurfaceMetGan, SolarMultiStepGan)
145 changes: 145 additions & 0 deletions sup3r/models/abstract.py
@@ -51,6 +51,10 @@ def __init__(self):
self._stdevs = None
self._train_record = pd.DataFrame()
self._val_record = pd.DataFrame()
self._swa_weights = None
self._swa_n = 0
self._swa_enabled = False
self._pre_swa_weights = None

def load_network(self, model, name):
"""Load a CustomNetwork object from hidden layers config, .json file
@@ -722,6 +726,147 @@ def early_stop(history, column, threshold=0.005, n_epoch=5):

return stop

def enable_swa(self):
"""Enable stochastic weight averaging."""
self._swa_enabled = True
self._swa_weights = None
self._swa_n = 0
logger.info('SWA enabled')

def update_swa(self):
"""Update the SWA running average with current model weights.

This should be called at the end of epochs where you want to include
the current weights in the average (e.g., after each epoch in the
last 25% of training, or at the end of each LR cycle).
"""
if not self._swa_enabled:
return

current_weights = [w.numpy() for w in self.weights]

if self._swa_weights is None:
# First snapshot
self._swa_weights = current_weights
self._swa_n = 1
else:
for i, w in enumerate(current_weights):
self._swa_weights[i] = (
self._swa_weights[i] * self._swa_n + w
) / (self._swa_n + 1)
self._swa_n += 1

logger.info(f'Updated SWA weights (n={self._swa_n})')
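
The incremental update above is algebraically a running mean: after n calls, `_swa_weights` holds the arithmetic average of all n weight snapshots. A minimal standalone sketch of the same update rule (NumPy only; the snapshot values are invented for illustration):

import numpy as np

# Three fake weight snapshots standing in for per-epoch model weights
snapshots = [np.array([1.0, 2.0]), np.array([3.0, 4.0]), np.array([5.0, 6.0])]

avg, n = None, 0
for w in snapshots:
    # Same recurrence as update_swa: new_avg = (avg * n + w) / (n + 1)
    avg = w.copy() if avg is None else (avg * n + w) / (n + 1)
    n += 1

# The recurrence reproduces the plain arithmetic mean of the snapshots
assert np.allclose(avg, np.mean(snapshots, axis=0))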

def apply_swa(self, config, epoch, extras):
"""Apply SWA updates if enabled in config.

Parameters
----------
config : TrainingConfig
Training configuration object.
epoch : int
Current epoch number.
extras : dict
Dictionary of extra information to log for the current epoch.
This is updated in-place with any SWA-related information
(e.g., swa_n) and returned at the end.

Returns
-------
extras : dict
Updated dictionary of extra information to log for the
current epoch.
"""
if config.swa_start is not None and epoch >= config.swa_start:
# Switch to constant LR if specified (only once)
if config.swa_lr is not None and epoch == config.swa_start:
self.update_optimizer('all', learning_rate=config.swa_lr)
logger.info(f'Switched to SWA constant LR: {config.swa_lr}')

# Update SWA weights at specified frequency
if (epoch - config.swa_start) % config.swa_freq == 0:
self.update_swa()
extras['swa_n'] = self._swa_n

return extras
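
`apply_swa` only reads three attributes from the config. A minimal sketch of those fields, assuming a dataclass-style container; the real `TrainingConfig` lives in sup3r/models/training_config.py and its full definition is not shown in this diff, so the defaults here are assumptions:

from dataclasses import dataclass

@dataclass
class SWAConfigSketch:  # hypothetical stand-in, not the real TrainingConfig
    swa_start: int | None = None  # first epoch to include in the average; None disables SWA
    swa_freq: int = 1             # take a snapshot every swa_freq epochs after swa_start
    swa_lr: float | None = None   # optional constant LR applied once at swa_start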

def swap_swa_weights(self):
"""Replace current model weights with SWA averaged weights.

Call this after training is complete to use the averaged weights.
"""
if self._swa_weights is None:
logger.warning('No SWA weights to swap')
return

logger.info(
f'Swapping to SWA weights (averaged over {self._swa_n} snapshots)'
)

# Store original weights as backup
self._pre_swa_weights = [w.numpy() for w in self.weights]

# Set model weights to SWA averages
for weight_var, swa_weight in zip(self.weights, self._swa_weights):
weight_var.assign(swa_weight)

def restore_pre_swa_weights(self):
"""Restore weights from before SWA swap (for comparison/debugging)."""
if self._pre_swa_weights is None:
logger.warning('No pre-SWA weights to restore')
return

for weight_var, pre_swa_weight in zip(
self.weights, self._pre_swa_weights
):
weight_var.assign(pre_swa_weight)
logger.info('Restored pre-SWA weights')
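
Since `swap_swa_weights` backs up the pre-swap weights, the averaged and original checkpoints can be validated against each other. A sketch, assuming a hypothetical `validate` helper (not in this diff) that returns a scalar validation loss:

model.swap_swa_weights()
model.update_bn_stats(batch_handler)  # refresh BN stats under the averaged weights
swa_loss = validate(model)            # hypothetical helper, not part of this PR

model.restore_pre_swa_weights()
base_loss = validate(model)

if swa_loss <= base_loss:             # keep whichever weights validate better
    model.swap_swa_weights()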

def update_bn_stats(self, batch_handler, n_batches=None):
"""Update batch normalization statistics after swapping to SWA weights.

This is critical because BN layers have running statistics computed
during training with the original weights, not the SWA averaged
weights.

Parameters
----------
batch_handler : sup3r.preprocessing.BatchHandler
BatchHandler to iterate through for BN updates
n_batches : int | None
Number of batches to use. If None, uses all available batches.
"""
has_bn_layers = any(
isinstance(layer, tf.keras.layers.BatchNormalization)
for layer in self.generator.layers
)
if not has_bn_layers:
logger.info(
'No batch normalization layers found, skipping BN stats update'
)
return

logger.info('Updating batch normalization statistics for SWA model...')

# Reset BN layer statistics
for layer in self.generator.layers:
if isinstance(layer, tf.keras.layers.BatchNormalization):
layer.moving_mean.assign(tf.zeros_like(layer.moving_mean))
layer.moving_variance.assign(
tf.ones_like(layer.moving_variance)
)

# Do forward passes to recompute statistics
count = 0
for batch in batch_handler:
if n_batches is not None and count >= n_batches:
break
_ = self.generate(batch.low_res, norm_in=True)
count += 1

logger.info(f'Updated BN stats using {count} batches')
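
Taken together, a training script might wire these hooks up roughly as follows. This is a minimal sketch rather than sup3r's actual training entry point; the epoch loop and the construction of `model`, `config`, and `batch_handler` are assumed:

model.enable_swa()

for epoch in range(n_epochs):
    extras = {}
    # ... one epoch of normal GAN training ...
    extras = model.apply_swa(config, epoch, extras)

# After training: adopt the averaged weights, then recompute BN statistics
model.swap_swa_weights()
model.update_bn_stats(batch_handler, n_batches=100)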

@abstractmethod
def save(self, out_dir):
"""Save the model with its sub-networks to a directory.