diff --git a/pyproject.toml b/pyproject.toml
index 80bc45ec41..b745b6dfe6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,7 +36,7 @@ dependencies = [
   "netCDF4>=1.5.8,<1.7",
   "cftime>=1.6.2",
   "matplotlib>=3.1",
-  "numpy>=1.7.0",
+  "numpy>=1.7.0,<2.0.0",
   "pandas>=2.0",
   "pillow>=10.0",
   "pytest>=5.2",
diff --git a/sup3r/models/base.py b/sup3r/models/base.py
index 8586352d6d..fe135300de 100644
--- a/sup3r/models/base.py
+++ b/sup3r/models/base.py
@@ -495,37 +495,25 @@ def calc_loss_gen_content(self, hi_res_true, hi_res_gen):
         )
         return self.loss_fun(hi_res_true[..., slc], hi_res_gen[..., slc])

-    @staticmethod
-    @tf.function
-    def calc_loss_gen_advers(disc_out_gen):
-        """Calculate the adversarial component of the loss term for the
-        generator model.
-
-        Parameters
-        ----------
-        disc_out_gen : tf.Tensor
-            Raw discriminator outputs from the discriminator model
-            predicting only on hi_res_gen (not on hi_res_true).
-
-        Returns
-        -------
-        loss_gen_advers : tf.Tensor
-            0D tensor generator model loss for the adversarial component of the
-            generator loss term.
-        """
-
-        # note that these have flipped labels from the discriminator
-        # loss because of the opposite optimization goal
-        loss_gen_advers = tf.nn.sigmoid_cross_entropy_with_logits(
-            logits=disc_out_gen, labels=tf.ones_like(disc_out_gen)
-        )
-        return tf.reduce_mean(loss_gen_advers)
-
     @staticmethod
     @tf.function
     def calc_loss_disc(disc_out_true, disc_out_gen):
         """Calculate the loss term for the discriminator model (either the
-        spatial or temporal discriminator).
+        spatial or temporal discriminator). This uses the relativistic
+        discriminator loss described in [Wang2018]_.
+
+        Note: Instead of training the discriminator to label data as either
+        real or fake, this trains the disc to label data as more or less
+        realistic. To use this for adversarial loss we simply set
+        ``disc_out_true`` to ``disc_out_gen`` and vice versa, which then
+        encourages the generator to produce output which is "more realistic"
+        than the true high-res data.
+
+        References
+        ----------
+        .. [Wang2018] Wang, Xintao, et al. "Esrgan: Enhanced super-resolution
+            generative adversarial networks." Proceedings of the European
+            conference on computer vision (ECCV) workshops. 2018.

         Parameters
         ----------
@@ -542,95 +530,17 @@ def calc_loss_disc(disc_out_true, disc_out_gen):
             0D tensor discriminator model loss for either the spatial or
             temporal component of the super resolution generated output.
         """
-
-        # note that these have flipped labels from the generator
-        # loss because of the opposite optimization goal
-        logits = tf.concat([disc_out_true, disc_out_gen], axis=0)
+        true_logits = disc_out_true - tf.reduce_mean(disc_out_gen)
+        fake_logits = disc_out_gen - tf.reduce_mean(disc_out_true)
+        logits = tf.concat([true_logits, fake_logits], axis=0)
         labels = tf.concat(
             [tf.ones_like(disc_out_true), tf.zeros_like(disc_out_gen)], axis=0
         )
-
         loss_disc = tf.nn.sigmoid_cross_entropy_with_logits(
             logits=logits, labels=labels
         )
         return tf.reduce_mean(loss_disc)

-    @tf.function
-    def calc_loss(
-        self,
-        hi_res_true,
-        hi_res_gen,
-        weight_gen_advers=0.001,
-        train_gen=True,
-        train_disc=False,
-    ):
-        """Calculate the GAN loss function using generated and true high
-        resolution data.
-
-        Parameters
-        ----------
-        hi_res_true : tf.Tensor
-            Ground truth high resolution spatiotemporal data.
-        hi_res_gen : tf.Tensor
-            Superresolved high resolution spatiotemporal data generated by the
-            generative model.
-        weight_gen_advers : float
-            Weight factor for the adversarial loss component of the generator
-            vs. the discriminator.
-        train_gen : bool
-            True if generator is being trained, then loss=loss_gen
-        train_disc : bool
-            True if disc is being trained, then loss=loss_disc
-
-        Returns
-        -------
-        loss : tf.Tensor
-            0D tensor representing the loss value for the network being trained
-            (either generator or one of the discriminators)
-        loss_details : dict
-            Namespace of the breakdown of loss components
-        """
-        hi_res_gen = self._combine_loss_input(hi_res_true, hi_res_gen)
-
-        if hi_res_gen.shape != hi_res_true.shape:
-            msg = (
-                'The tensor shapes of the synthetic output {} and '
-                'true high res {} did not have matching shape! '
-                'Check the spatiotemporal enhancement multipliers in your '
-                'your model config and data handlers.'.format(
-                    hi_res_gen.shape, hi_res_true.shape
-                )
-            )
-            logger.error(msg)
-            raise RuntimeError(msg)
-
-        disc_out_true = self._tf_discriminate(hi_res_true)
-        disc_out_gen = self._tf_discriminate(hi_res_gen)
-
-        loss_gen_content, loss_gen_content_details = (
-            self.calc_loss_gen_content(hi_res_true, hi_res_gen)
-        )
-        loss_gen_advers = self.calc_loss_gen_advers(disc_out_gen)
-        loss_gen = loss_gen_content + weight_gen_advers * loss_gen_advers
-
-        loss_disc = self.calc_loss_disc(disc_out_true, disc_out_gen)
-
-        loss = None
-        if train_gen:
-            loss = loss_gen
-        elif train_disc:
-            loss = loss_disc
-
-        loss_details = {
-            'loss_gen': loss_gen,
-            'loss_gen_content': loss_gen_content,
-            'loss_gen_advers': loss_gen_advers,
-            'loss_disc': loss_disc,
-        }
-        loss_details.update(loss_gen_content_details)
-
-        return loss, loss_details
-
     def update_adversarial_weights(
         self,
         history,
@@ -904,6 +814,89 @@ def train(

         batch_handler.stop()

+    def calc_loss(
+        self,
+        hi_res_true,
+        hi_res_gen,
+        weight_gen_advers=0.001,
+        train_gen=True,
+        train_disc=False,
+        compute_disc=False,
+    ):
+        """Calculate the GAN loss function using generated and true high
+        resolution data.
+
+        Parameters
+        ----------
+        hi_res_true : tf.Tensor
+            Ground truth high resolution spatiotemporal data.
+        hi_res_gen : tf.Tensor
+            Superresolved high resolution spatiotemporal data generated by the
+            generative model.
+        weight_gen_advers : float
+            Weight factor for the adversarial loss component of the generator
+            vs. the discriminator.
+        train_gen : bool
+            True if generator is being trained, then loss=loss_gen
+        train_disc : bool
+            True if disc is being trained, then loss=loss_disc
+        compute_disc : bool
+            True if discriminator loss should be computed, even if not being
+            trained. Outside of generator pre-training this needs to be
+            tracked to determine if the discriminator is "too good" or "not
+            good enough"
+
+        Returns
+        -------
+        loss : tf.Tensor
+            0D tensor representing the loss value for the network being trained
+            (either generator or one of the discriminators)
+        loss_details : dict
+            Namespace of the breakdown of loss components
+        """
+        hi_res_gen = self._combine_loss_input(hi_res_true, hi_res_gen)
+
+        if hi_res_gen.shape != hi_res_true.shape:
+            msg = (
+                'The tensor shapes of the synthetic output {} and '
+                'true high res {} did not have matching shape! '
+                'Check the spatiotemporal enhancement multipliers in your '
+                'model config and data handlers.'.format(
+                    hi_res_gen.shape, hi_res_true.shape
+                )
+            )
+            logger.error(msg)
+            raise RuntimeError(msg)
+
+        disc_out_true = self._tf_discriminate(hi_res_true)
+        disc_out_gen = self._tf_discriminate(hi_res_gen)
+
+        loss_details = {}
+        loss = None
+
+        if compute_disc or train_disc:
+            loss_details['loss_disc'] = self.calc_loss_disc(
+                disc_out_true=disc_out_true, disc_out_gen=disc_out_gen
+            )
+
+        if train_gen:
+            loss_gen_content, loss_gen_content_details = (
+                self.calc_loss_gen_content(hi_res_true, hi_res_gen)
+            )
+            loss_gen_advers = self.calc_loss_disc(
+                disc_out_true=disc_out_gen, disc_out_gen=disc_out_true
+            )
+            loss = loss_gen_content + weight_gen_advers * loss_gen_advers
+            loss_details['loss_gen'] = loss
+            loss_details['loss_gen_content'] = loss_gen_content
+            loss_details['loss_gen_advers'] = loss_gen_advers
+            loss_details.update(loss_gen_content_details)
+
+        elif train_disc:
+            loss = loss_details['loss_disc']
+
+        return loss, loss_details
+
     def calc_val_loss(self, batch_handler, weight_gen_advers):
         """Calculate the validation loss at the current state of model
         training
@@ -925,11 +918,7 @@ def calc_val_loss(self, batch_handler, weight_gen_advers):
             hi_res_exo = self.get_hr_exo_input(batch.high_res)
             hi_res_gen = self._tf_generate(batch.low_res, hi_res_exo)
             _, v_loss_details = self.calc_loss(
-                batch.high_res,
-                hi_res_gen,
-                weight_gen_advers=weight_gen_advers,
-                train_gen=False,
-                train_disc=False,
+                batch.high_res, hi_res_gen, weight_gen_advers=weight_gen_advers
             )
             self._val_record = self.update_loss_details(
                 self._val_record,
@@ -1004,6 +993,7 @@ def _train_batch(
                 optimizer=self.optimizer,
                 train_gen=True,
                 train_disc=False,
+                compute_disc=train_disc,
                 multi_gpu=multi_gpu,
             )

@@ -1024,7 +1014,9 @@ def _train_batch(
             b_loss_details['disc_train_frac'] = float(trained_disc)

         return b_loss_details

-    def _post_batch(self, ib, b_loss_details, loss_mean_window, n_batches):
+    def _post_batch(
+        self, ib, b_loss_details, loss_mean_window, n_batches, previous_means
+    ):
         """Update loss details after the current batch and write to log.

         Parameters
@@ -1037,12 +1029,19 @@ def _post_batch(
             Number of batches to use in the running loss means
         n_batches : int
             Number of batches in an epoch
+        previous_means : dict
+            Dictionary of previous loss means over the loss_mean_window

         Returns
         -------
         loss_means : dict
             Dictionary of running loss means
         """
+        # set default values for when either disc / gen is not trained for the
+        # last batch
+        for key, val in previous_means.items():
+            if key.startswith('train_'):
+                b_loss_details.setdefault(key.replace('train_', ''), val)

         self._train_record = self.update_loss_details(
             self._train_record,
@@ -1171,7 +1170,11 @@ def _train_epoch(
                )

                loss_means = self._post_batch(
-                    ib, b_loss_details, loss_mean_window, len(batch_handler)
+                    ib,
+                    b_loss_details,
+                    loss_mean_window,
+                    len(batch_handler),
+                    loss_means,
                )

        self.total_batches += len(batch_handler)
diff --git a/sup3r/models/solar_cc.py b/sup3r/models/solar_cc.py
index 300a5de12f..66bb9e16f6 100644
--- a/sup3r/models/solar_cc.py
+++ b/sup3r/models/solar_cc.py
@@ -97,6 +97,7 @@ def calc_loss(
         weight_gen_advers=0.001,
         train_gen=True,
         train_disc=False,
+        compute_disc=False,
     ):
         """Calculate the GAN loss function using generated and true high
         resolution data.
@@ -115,6 +116,11 @@ def calc_loss(
             True if generator is being trained, then loss=loss_gen
         train_disc : bool
             True if disc is being trained, then loss=loss_disc
+        compute_disc : bool
+            True if discriminator loss should be computed, even if not being
+            trained. Outside of generator pre-training this needs to be
+            tracked to determine if the discriminator is "too good" or "not
+            good enough"

         Returns
         -------
@@ -168,30 +174,9 @@ def calc_loss(
             for x in range(0, 24 * n_days, 24)
         ]

-        # sample only daylight hours for disc training and gen content loss
         disc_out_true = []
         disc_out_gen = []
         loss_gen_content = 0.0
-        ziter = zip(sub_day_slices, point_loss_slices, day_24h_slices)
-        for tslice_sub, tslice_ploss, tslice_24h in ziter:
-            hr_true_sub = hi_res_true[:, :, :, tslice_sub, :]
-            hr_gen_24h = hi_res_gen[:, :, :, tslice_24h, :]
-            hr_true_ploss = hi_res_true[:, :, :, tslice_ploss, :]
-            hr_gen_ploss = hi_res_gen[:, :, :, tslice_ploss, :]
-
-            hr_true_mean = tf.math.reduce_mean(hr_true_sub, axis=3)
-            hr_gen_mean = tf.math.reduce_mean(hr_gen_24h, axis=3)
-
-            gen_c_sub, gen_c_sub_details = self.calc_loss_gen_content(
-                hr_true_ploss, hr_gen_ploss
-            )
-            gen_c_24h, gen_c_24h_details = self.calc_loss_gen_content(
-                hr_true_mean, hr_gen_mean
-            )
-            loss_gen_content = gen_c_sub + gen_c_24h
-
-            disc_t = self._tf_discriminate(hr_true_sub)
-            disc_out_true.append(disc_t)

         # Randomly sample daylight windows from generated data. Better than
         # strided samples covering full day because the random samples will
@@ -204,32 +189,64 @@ def calc_loss(
             disc_g = self._tf_discriminate(hi_res_gen[:, :, :, t0:t1, :])
             disc_out_gen.append(disc_g)

+        # sample only daylight hours for disc training
+        ziter = zip(sub_day_slices, point_loss_slices, day_24h_slices)
+        for tslice_sub, _, _ in ziter:
+            hr_true_sub = hi_res_true[:, :, :, tslice_sub, :]
+            disc_t = self._tf_discriminate(hr_true_sub)
+            disc_out_true.append(disc_t)
+
         disc_out_true = tf.concat([disc_out_true], axis=0)
         disc_out_gen = tf.concat([disc_out_gen], axis=0)
-        loss_disc = self.calc_loss_disc(disc_out_true, disc_out_gen)
-
-        loss_gen_content /= len(sub_day_slices)
-        loss_gen_advers = self.calc_loss_gen_advers(disc_out_gen)
-        loss_gen = loss_gen_content + weight_gen_advers * loss_gen_advers

+        loss_details = {}
         loss = None
+
+        if compute_disc or train_disc:
+            loss_details['loss_disc'] = self.calc_loss_disc(
+                disc_out_true=disc_out_true, disc_out_gen=disc_out_gen
+            )
+
         if train_gen:
-            loss = loss_gen
+            # sample only daylight hours for content loss
+            ziter = zip(sub_day_slices, point_loss_slices, day_24h_slices)
+            for tslice_sub, tslice_ploss, tslice_24h in ziter:
+                hr_true_sub = hi_res_true[:, :, :, tslice_sub, :]
+                hr_gen_24h = hi_res_gen[:, :, :, tslice_24h, :]
+                hr_true_ploss = hi_res_true[:, :, :, tslice_ploss, :]
+                hr_gen_ploss = hi_res_gen[:, :, :, tslice_ploss, :]
+
+                hr_true_mean = tf.math.reduce_mean(hr_true_sub, axis=3)
+                hr_gen_mean = tf.math.reduce_mean(hr_gen_24h, axis=3)
+
+                gen_c_sub, gen_c_sub_details = self.calc_loss_gen_content(
+                    hr_true_ploss, hr_gen_ploss
+                )
+                gen_c_24h, gen_c_24h_details = self.calc_loss_gen_content(
+                    hr_true_mean, hr_gen_mean
+                )
+                loss_gen_content += (gen_c_sub + gen_c_24h) / len(
+                    sub_day_slices
+                )
+                for k, v in gen_c_sub_details.items():
+                    loss_details[f'c_sub_{k}'] = loss_details.get(
+                        f'c_sub_{k}', 0
+                    ) + v / len(sub_day_slices)
+                for k, v in gen_c_24h_details.items():
+                    loss_details[f'c_24h_{k}'] = loss_details.get(
+                        f'c_24h_{k}', 0
+                    ) + v / len(sub_day_slices)
+
+            loss_gen_advers = self.calc_loss_disc(
+                disc_out_true=disc_out_gen, disc_out_gen=disc_out_true
+            )
+            loss = loss_gen_content + weight_gen_advers * loss_gen_advers
+            loss_details['loss_gen'] = loss
+            loss_details['loss_gen_content'] = loss_gen_content
+            loss_details['loss_gen_advers'] = loss_gen_advers
+
         elif train_disc:
-            loss = loss_disc
-
-        loss_details = {
-            'loss_gen': loss_gen,
-            'loss_gen_content': loss_gen_content,
-            'loss_gen_advers': loss_gen_advers,
-            'loss_disc': loss_disc,
-        }
-        loss_details.update(
-            {f'c_sub_{k}': v for k, v in gen_c_sub_details.items()}
-        )
-        loss_details.update(
-            {f'c_24h_{k}': v for k, v in gen_c_24h_details.items()}
-        )
+            loss = loss_details['loss_disc']

         return loss, loss_details
diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py
index 9e94a6660a..3616e23317 100644
--- a/sup3r/pipeline/forward_pass.py
+++ b/sup3r/pipeline/forward_pass.py
@@ -386,7 +386,11 @@ def get_node_cmd(cls, config):
     @classmethod
     def _constant_output_check(cls, out_data, allowed_const):
         """Check if forward pass output is constant. This can happen when the
-        chunk going through the forward pass is too big.
+        chunk going through the forward pass is too big. This is due to a
+        tensorflow padding bug that occurs when the padding mode is set to
+        'reflect'. With the currently preferred tensorflow version (2.15.1)
+        this results in scrambled output rather than constant output.
+        https://github.com/tensorflow/tensorflow/issues/91027

         Parameters
         ----------
diff --git a/sup3r/preprocessing/derivers/utilities.py b/sup3r/preprocessing/derivers/utilities.py
index d7b1776b21..10759ad064 100644
--- a/sup3r/preprocessing/derivers/utilities.py
+++ b/sup3r/preprocessing/derivers/utilities.py
@@ -159,7 +159,8 @@ def transform_rotate_wind(ws, wd, lat_lon):
         (spatial_1, spatial_2, temporal)
     wd : Union[np.ndarray, da.core.Array]
         3D array of high res winddirection data. Angle is in degrees and
-        measured relative to the south_north direction.
+        measured clockwise from the north direction. This is the direction
+        the wind is coming from.
         (spatial_1, spatial_2, temporal)
     lat_lon : Union[np.ndarray, da.core.Array]
        3D array of lat lon
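Note on the loss change above: the new ``calc_loss_disc`` implements the relativistic average discriminator loss from [Wang2018]_ (ESRGAN), and the generator's adversarial term is obtained by calling the same function with the two arguments swapped. Below is a minimal standalone sketch of that idea with made-up logits; it is an illustration only, not the sup3r implementation.

import tensorflow as tf

def relativistic_disc_loss(disc_out_true, disc_out_gen):
    """Cross entropy on "is this sample more realistic than the average
    sample from the other set?" using raw (pre-sigmoid) logits."""
    true_logits = disc_out_true - tf.reduce_mean(disc_out_gen)
    fake_logits = disc_out_gen - tf.reduce_mean(disc_out_true)
    logits = tf.concat([true_logits, fake_logits], axis=0)
    labels = tf.concat(
        [tf.ones_like(disc_out_true), tf.zeros_like(disc_out_gen)], axis=0
    )
    loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
    return tf.reduce_mean(loss)

# made-up discriminator logits for illustration only
disc_out_true = tf.random.normal((8, 1))  # logits on true high-res data
disc_out_gen = tf.random.normal((8, 1))   # logits on generated data

# discriminator update: true data should score as more realistic than fake
loss_disc = relativistic_disc_loss(disc_out_true, disc_out_gen)

# generator adversarial term: swap the arguments so the generated output is
# pushed to look "more realistic" than the true high-res data
loss_gen_advers = relativistic_disc_loss(disc_out_gen, disc_out_true)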
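Note on the winddirection docstring change above: ``wd`` follows the meteorological convention (degrees clockwise from north, giving the direction the wind is coming from). The sketch below shows how that convention maps to eastward/northward components using the standard decomposition; it is an assumption-level illustration, not the actual sup3r deriver code.

import numpy as np

ws = np.array([5.0, 10.0])    # wind speed (m/s)
wd = np.array([0.0, 270.0])   # wind direction (degrees clockwise from north)

theta = np.radians(wd)
u = -ws * np.sin(theta)  # eastward (west_east) component
v = -ws * np.cos(theta)  # northward (south_north) component

# wd = 0   -> wind from the north blowing south: u = 0,  v = -5
# wd = 270 -> wind from the west blowing east:   u = 10, v ~ 0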