2 changes: 1 addition & 1 deletion pyproject.toml
@@ -36,7 +36,7 @@ dependencies = [
"netCDF4>=1.5.8,<1.7",
"cftime>=1.6.2",
"matplotlib>=3.1",
"numpy>=1.7.0",
"numpy>=1.7.0,<2.0.0",
"pandas>=2.0",
"pillow>=10.0",
"pytest>=5.2",
233 changes: 118 additions & 115 deletions sup3r/models/base.py
@@ -495,37 +495,25 @@ def calc_loss_gen_content(self, hi_res_true, hi_res_gen):
)
return self.loss_fun(hi_res_true[..., slc], hi_res_gen[..., slc])

@staticmethod
@tf.function
def calc_loss_gen_advers(disc_out_gen):
"""Calculate the adversarial component of the loss term for the
generator model.

Parameters
----------
disc_out_gen : tf.Tensor
Raw discriminator outputs from the discriminator model
predicting only on hi_res_gen (not on hi_res_true).

Returns
-------
loss_gen_advers : tf.Tensor
0D tensor generator model loss for the adversarial component of the
generator loss term.
"""

# note that these have flipped labels from the discriminator
# loss because of the opposite optimization goal
loss_gen_advers = tf.nn.sigmoid_cross_entropy_with_logits(
logits=disc_out_gen, labels=tf.ones_like(disc_out_gen)
)
return tf.reduce_mean(loss_gen_advers)

@staticmethod
@tf.function
def calc_loss_disc(disc_out_true, disc_out_gen):
"""Calculate the loss term for the discriminator model (either the
spatial or temporal discriminator).
spatial or temporal discriminator). This uses the relativistic
discriminator loss described in [Wang2018]_.

Note: Instead of training the discriminator to label data as either
real or fake, this trains the disc to label data as more or less
realistic. To use this for the generator's adversarial loss we simply
swap ``disc_out_true`` and ``disc_out_gen``, which then encourages the
generator to produce output that is "more realistic" than the true
high-res data.

References
----------
.. [Wang2018] Wang, Xintao, et al. "Esrgan: Enhanced super-resolution
generative adversarial networks." Proceedings of the European
conference on computer vision (ECCV) workshops. 2018.

Parameters
----------
@@ -542,95 +530,17 @@ def calc_loss_disc(disc_out_true, disc_out_gen):
0D tensor discriminator model loss for either the spatial or
temporal component of the super resolution generated output.
"""

# note that these have flipped labels from the generator
# loss because of the opposite optimization goal
logits = tf.concat([disc_out_true, disc_out_gen], axis=0)
true_logits = disc_out_true - tf.reduce_mean(disc_out_gen)
fake_logits = disc_out_gen - tf.reduce_mean(disc_out_true)
logits = tf.concat([true_logits, fake_logits], axis=0)
labels = tf.concat(
[tf.ones_like(disc_out_true), tf.zeros_like(disc_out_gen)], axis=0
)

loss_disc = tf.nn.sigmoid_cross_entropy_with_logits(
logits=logits, labels=labels
)
return tf.reduce_mean(loss_disc)
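
For reference, a minimal standalone sketch of this relativistic discriminator loss (Wang et al., 2018) is below. The function name and the random test tensors are illustrative only and are not part of this PR; the body mirrors the logic added above.

import tensorflow as tf

def relativistic_disc_loss(disc_out_true, disc_out_gen):
    # "Real" outputs should score above the mean "fake" output, and
    # "fake" outputs should score below the mean "real" output.
    true_logits = disc_out_true - tf.reduce_mean(disc_out_gen)
    fake_logits = disc_out_gen - tf.reduce_mean(disc_out_true)
    logits = tf.concat([true_logits, fake_logits], axis=0)
    labels = tf.concat(
        [tf.ones_like(disc_out_true), tf.zeros_like(disc_out_gen)], axis=0
    )
    loss = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=logits, labels=labels
    )
    return tf.reduce_mean(loss)

# Swapping the arguments flips the optimization goal, which is how the
# generator's adversarial term is computed further down in this diff.
disc_true = tf.random.normal((8, 1))
disc_gen = tf.random.normal((8, 1))
loss_disc = relativistic_disc_loss(disc_true, disc_gen)
loss_gen_advers = relativistic_disc_loss(disc_gen, disc_true)
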

@tf.function
def calc_loss(
self,
hi_res_true,
hi_res_gen,
weight_gen_advers=0.001,
train_gen=True,
train_disc=False,
):
"""Calculate the GAN loss function using generated and true high
resolution data.

Parameters
----------
hi_res_true : tf.Tensor
Ground truth high resolution spatiotemporal data.
hi_res_gen : tf.Tensor
Superresolved high resolution spatiotemporal data generated by the
generative model.
weight_gen_advers : float
Weight factor for the adversarial loss component of the generator
vs. the discriminator.
train_gen : bool
True if generator is being trained, then loss=loss_gen
train_disc : bool
True if disc is being trained, then loss=loss_disc

Returns
-------
loss : tf.Tensor
0D tensor representing the loss value for the network being trained
(either generator or one of the discriminators)
loss_details : dict
Namespace of the breakdown of loss components
"""
hi_res_gen = self._combine_loss_input(hi_res_true, hi_res_gen)

if hi_res_gen.shape != hi_res_true.shape:
msg = (
'The tensor shapes of the synthetic output {} and '
'true high res {} did not have matching shape! '
'Check the spatiotemporal enhancement multipliers in your '
'model config and data handlers.'.format(
hi_res_gen.shape, hi_res_true.shape
)
)
logger.error(msg)
raise RuntimeError(msg)

disc_out_true = self._tf_discriminate(hi_res_true)
disc_out_gen = self._tf_discriminate(hi_res_gen)

loss_gen_content, loss_gen_content_details = (
self.calc_loss_gen_content(hi_res_true, hi_res_gen)
)
loss_gen_advers = self.calc_loss_gen_advers(disc_out_gen)
loss_gen = loss_gen_content + weight_gen_advers * loss_gen_advers

loss_disc = self.calc_loss_disc(disc_out_true, disc_out_gen)

loss = None
if train_gen:
loss = loss_gen
elif train_disc:
loss = loss_disc

loss_details = {
'loss_gen': loss_gen,
'loss_gen_content': loss_gen_content,
'loss_gen_advers': loss_gen_advers,
'loss_disc': loss_disc,
}
loss_details.update(loss_gen_content_details)

return loss, loss_details

def update_adversarial_weights(
self,
history,
@@ -904,6 +814,89 @@ def train(

batch_handler.stop()

def calc_loss(
self,
hi_res_true,
hi_res_gen,
weight_gen_advers=0.001,
train_gen=True,
train_disc=False,
compute_disc=False,
):
"""Calculate the GAN loss function using generated and true high
resolution data.

Parameters
----------
hi_res_true : tf.Tensor
Ground truth high resolution spatiotemporal data.
hi_res_gen : tf.Tensor
Superresolved high resolution spatiotemporal data generated by the
generative model.
weight_gen_advers : float
Weight factor for the adversarial loss component of the generator
vs. the discriminator.
train_gen : bool
True if generator is being trained, then loss=loss_gen
train_disc : bool
True if disc is being trained, then loss=loss_disc
compute_disc : bool
True if discriminator loss should be computed, even if not being
trained. Outside of generator pre-training this needs to be
tracked to determine whether the discriminator is "too good" or
"not good enough".

Returns
-------
loss : tf.Tensor
0D tensor representing the loss value for the network being trained
(either generator or one of the discriminators)
loss_details : dict
Namespace of the breakdown of loss components
"""
hi_res_gen = self._combine_loss_input(hi_res_true, hi_res_gen)

if hi_res_gen.shape != hi_res_true.shape:
msg = (
'The tensor shapes of the synthetic output {} and '
'true high res {} did not have matching shape! '
'Check the spatiotemporal enhancement multipliers in your '
'model config and data handlers.'.format(
hi_res_gen.shape, hi_res_true.shape
)
)
logger.error(msg)
raise RuntimeError(msg)

disc_out_true = self._tf_discriminate(hi_res_true)
disc_out_gen = self._tf_discriminate(hi_res_gen)

loss_details = {}
loss = None

if compute_disc or train_disc:
loss_details['loss_disc'] = self.calc_loss_disc(
disc_out_true=disc_out_true, disc_out_gen=disc_out_gen
)

if train_gen:
loss_gen_content, loss_gen_content_details = (
self.calc_loss_gen_content(hi_res_true, hi_res_gen)
)
loss_gen_advers = self.calc_loss_disc(
disc_out_true=disc_out_gen, disc_out_gen=disc_out_true
)
loss = loss_gen_content + weight_gen_advers * loss_gen_advers
loss_details['loss_gen'] = loss
loss_details['loss_gen_content'] = loss_gen_content
loss_details['loss_gen_advers'] = loss_gen_advers
loss_details.update(loss_gen_content_details)

elif train_disc:
loss = loss_details['loss_disc']

return loss, loss_details
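
A short usage sketch of the reworked calc_loss follows; the model instance and input tensors are assumed for illustration, and only the keyword flags come from this diff. Generator steps return the generator loss, discriminator steps return the discriminator loss, and compute_disc=True tracks the discriminator score without training it.

# Generator update, while still recording the discriminator loss:
loss, details = model.calc_loss(
    hi_res_true, hi_res_gen, weight_gen_advers=0.001,
    train_gen=True, train_disc=False, compute_disc=True,
)  # loss == details['loss_gen']

# Discriminator update:
loss, details = model.calc_loss(
    hi_res_true, hi_res_gen, train_gen=False, train_disc=True,
)  # loss == details['loss_disc']

# Metrics only: loss is None and details carries whatever was computed.
_, details = model.calc_loss(
    hi_res_true, hi_res_gen, train_gen=False, train_disc=False,
    compute_disc=True,
)
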

def calc_val_loss(self, batch_handler, weight_gen_advers):
"""Calculate the validation loss at the current state of model training

@@ -925,11 +918,7 @@ def calc_val_loss(self, batch_handler, weight_gen_advers):
hi_res_exo = self.get_hr_exo_input(batch.high_res)
hi_res_gen = self._tf_generate(batch.low_res, hi_res_exo)
_, v_loss_details = self.calc_loss(
batch.high_res,
hi_res_gen,
weight_gen_advers=weight_gen_advers,
train_gen=False,
train_disc=False,
batch.high_res, hi_res_gen, weight_gen_advers=weight_gen_advers
)
self._val_record = self.update_loss_details(
self._val_record,
Expand Down Expand Up @@ -1004,6 +993,7 @@ def _train_batch(
optimizer=self.optimizer,
train_gen=True,
train_disc=False,
compute_disc=train_disc,
multi_gpu=multi_gpu,
)

@@ -1024,7 +1014,9 @@ def _train_batch(
b_loss_details['disc_train_frac'] = float(trained_disc)
return b_loss_details

def _post_batch(self, ib, b_loss_details, loss_mean_window, n_batches):
def _post_batch(
self, ib, b_loss_details, loss_mean_window, n_batches, previous_means
):
"""Update loss details after the current batch and write to log.

Parameters
@@ -1037,12 +1029,19 @@ def _post_batch(self, ib, b_loss_details, loss_mean_window, n_batches):
Number of batches to use in the running loss means
n_batches : int
Number of batches in an epoch
previous_means : dict
Dictionary of previous loss means over the loss_mean_window

Returns
-------
loss_means : dict
Dictionary of running loss means
"""
# set default values for when either the disc or the gen was not
# trained on the last batch
for key, val in previous_means.items():
if key.startswith('train_'):
b_loss_details.setdefault(key.replace('train_', ''), val)

self._train_record = self.update_loss_details(
self._train_record,
@@ -1171,7 +1170,11 @@ def _train_epoch(
)

loss_means = self._post_batch(
ib, b_loss_details, loss_mean_window, len(batch_handler)
ib,
b_loss_details,
loss_mean_window,
len(batch_handler),
loss_means,
)

self.total_batches += len(batch_handler)