Added relativistic discriminator loss used in ESRGAN paper. #261
Conversation
Force-pushed 8ab0206 to 06c0990
@bnb32 Starting review - to clarify, you only implemented the feature "Relativistic average GAN (RaGAN) [20], which …"? Also, does this completely revise the previous disc method or is it an option?
@grantbuster Yeah, the RaGAN disc loss is what I added. This changes the previous method. It allows us to use …
…g half right should give 0.5 loss.
Force-pushed 06c0990 to 3c9ada8
grantbuster left a comment
Minor suggestion but LGTM
    if train_gen:
        loss = loss_gen
    elif train_disc:
        loss = loss_disc
This seems really inefficient to run all loss calculations and then only output one of them. Why don't we wrap the actual loss calculations in the if statement?
Good call. We've had this inefficient setup for a while.
Yeah, no blame - just reading our old code and thinking, well, that could have been done better haha
@grantbuster Actually, we need to compute the disc loss every batch to track whether to train it or not. We can skip the gen loss calcs though.
Gotcha, makes sense. Maybe note this in an in-line comment so future-us remembers haha
@grantbuster Nvm, how's this solution? d122ace
    loss_gen_content, loss_gen_content_details = (
        self.calc_loss_gen_content(hi_res_true, hi_res_gen)
    )
    loss_gen_advers = self.calc_loss_disc(
        disc_out_true=disc_out_gen, disc_out_gen=disc_out_true
    )
    loss_gen = loss_gen_content + weight_gen_advers * loss_gen_advers
    loss_disc = self.calc_loss_disc(
        disc_out_true=disc_out_true, disc_out_gen=disc_out_gen
    )

    loss = None
    if train_gen:
        loss = loss_gen
    elif train_disc:
        loss = loss_disc

    loss_details = {
        'loss_gen': loss_gen,
        'loss_gen_content': loss_gen_content,
        'loss_gen_advers': loss_gen_advers,
        'loss_disc': loss_disc,
    }
Suggested change:

    loss_details = {}
    if train_gen:
        loss_gen_content, loss_gen_content_details = (
            self.calc_loss_gen_content(hi_res_true, hi_res_gen)
        )
        loss_gen_advers = self.calc_loss_disc(
            disc_out_true=disc_out_gen, disc_out_gen=disc_out_true
        )
        loss = loss_gen_content + weight_gen_advers * loss_gen_advers
        loss_details['loss_gen'] = loss
        loss_details['loss_gen_content'] = loss_gen_content
        loss_details['loss_gen_advers'] = loss_gen_advers
    elif train_disc:
        loss = self.calc_loss_disc(
            disc_out_true=disc_out_true, disc_out_gen=disc_out_gen
        )
        loss_details['loss_disc'] = loss
    chunk going through the forward pass is too big. This is due to a
    tensorflow padding bug, with the padding mode set to 'reflect'.
    https://github.com/tensorflow/tensorflow/issues/91027
You might note that in our current TF version 2.15 this actually just results in scrambled outputs.
(which is very hard to detect)
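As a minimal illustration of the padding mode in question (a sketch using NumPy as a stand-in, not the sup3r code): 'reflect' padding mirrors interior values about the array edge without repeating the edge value, and TensorFlow's tf.pad with mode='REFLECT' only supports pad widths smaller than the padded dimension, which is why an oversized pad relative to the chunk is a problem in the first place.

```python
import numpy as np

# 'reflect' padding mirrors interior values about the array edge
# without repeating the edge value itself.
x = np.array([1, 2, 3])
padded = np.pad(x, (2, 2), mode='reflect')
print(padded)  # [3 2 1 2 3 2 1]

# Unlike NumPy (which keeps re-reflecting), tf.pad(..., mode='REFLECT')
# requires each pad width to be strictly smaller than the corresponding
# dimension - and per the linked issue, violating that in TF 2.15 can
# silently produce scrambled output rather than an error.
```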
… that tf 2.15.1 padding bug produces scrambled output instead of constant.
…mputing disc loss.
… loss calcs for disc and gen content.
Added relativistic discriminator loss used in ESRGAN paper.
This changes the previous disc loss calc (and the adversarial loss) to use the relativistic versions described in the ESRGAN paper. Because of the symmetry, we can remove the separate adversarial loss function and instead just swap the arguments in the new disc loss calc. It also seems worth doing since this improved on the SRGAN framework.
I'm also finding this to result in much more stable training, btw.
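For reference, the argument-swap symmetry can be sketched with a minimal NumPy version of the relativistic average (RaGAN) loss from the ESRGAN paper. This is hypothetical standalone code - `relativistic_disc_loss` is an illustrative stand-in, not the actual `calc_loss_disc` implementation in the PR:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def relativistic_disc_loss(disc_out_true, disc_out_gen):
    """Relativistic average discriminator loss (RaGAN, as in ESRGAN).

    Inputs are raw (pre-sigmoid) discriminator logits for the "real"
    and "generated" batches. D_Ra asks: is a real sample more realistic
    than the *average* fake sample (and vice versa)?
    """
    d_real = sigmoid(disc_out_true - np.mean(disc_out_gen))
    d_fake = sigmoid(disc_out_gen - np.mean(disc_out_true))
    eps = 1e-12  # guard against log(0)
    return (-np.mean(np.log(d_real + eps))
            - np.mean(np.log(1.0 - d_fake + eps)))

# The symmetry exploited here: the adversarial (generator) loss is the
# same function with the real/fake arguments swapped, mirroring
# calc_loss_disc(disc_out_true=disc_out_gen, disc_out_gen=disc_out_true).
real = np.array([2.0, 1.5, 1.0])    # discriminator favors these
fake = np.array([-1.0, -0.5, -2.0])
loss_disc = relativistic_disc_loss(real, fake)        # small: disc is winning
loss_gen_advers = relativistic_disc_loss(fake, real)  # large: gen is losing
```

Because swapping the arguments turns the discriminator loss into the generator's adversarial loss, the separate adversarial loss function becomes redundant.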