
Added relativistic discriminator loss used in ESRGAN paper. #261

Merged
bnb32 merged 7 commits into main from bnb/relativistic_disc
Apr 17, 2025

Conversation

@bnb32
Collaborator

@bnb32 bnb32 commented Mar 6, 2025

This changes the previous disc loss calc (and the adversarial loss) to use the relativistic versions described in the ESRGAN paper. The symmetry lets us remove the separate adversarial loss function and just swap the arguments in the new disc loss calc. It also seems worth doing since this approach improved on the SRGAN framework.

I'm also finding this results in much more stable training, btw.

@bnb32 bnb32 force-pushed the bnb/relativistic_disc branch from 8ab0206 to 06c0990 Compare March 7, 2025 23:49
@bnb32 bnb32 marked this pull request as ready for review March 28, 2025 15:24
@bnb32 bnb32 requested a review from grantbuster March 28, 2025 15:24
@grantbuster
Member

@bnb32 Starting review - to clarify, you only implemented the feature "Relativistic average GAN (RaGAN) [20], which learns to judge 'whether one image is more realistic than the other' rather than 'whether one image is real or fake'" from the Wang paper, right? Seems like Wang did a lot of different things and I'm trying to track what to pay attention to.

Also, does this completely revise the previous disc method or is it an option?

@bnb32
Collaborator Author

bnb32 commented Apr 16, 2025

@grantbuster Yeah, the RaGAN disc loss is what I added. This changes the previous method. It allows us to use disc_loss(true, gen) for the discriminator and disc_loss(gen, true) for adversarial loss, instead of two different methods for these, so I removed the previous function for adversarial loss.
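The symmetry described here can be sketched as follows. This is a minimal illustration of the relativistic average (RaGAN) discriminator loss from the ESRGAN paper, assuming the discriminator outputs raw logits; it is not the actual `calc_loss_disc` implementation in sup3r/models/base.py:

```python
import tensorflow as tf

def calc_loss_disc(disc_out_true, disc_out_gen):
    """Relativistic average (RaGAN) discriminator loss sketch.

    Scores how much more realistic the "true" logits look than the
    average of the "generated" logits, and vice versa.
    """
    # Relativistic logits: C(real) - mean(C(fake)) and the reverse
    real_rel = disc_out_true - tf.reduce_mean(disc_out_gen)
    fake_rel = disc_out_gen - tf.reduce_mean(disc_out_true)
    # Binary cross entropy on the relativistic logits: real -> 1, fake -> 0
    loss_real = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(real_rel), logits=real_rel
    )
    loss_fake = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(fake_rel), logits=fake_rel
    )
    return tf.reduce_mean(loss_real) + tf.reduce_mean(loss_fake)
```

With this form, `calc_loss_disc(disc_out_true, disc_out_gen)` is the discriminator loss and `calc_loss_disc(disc_out_gen, disc_out_true)` is the generator's adversarial term, which is why the separate adversarial loss function could be dropped.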

@bnb32 bnb32 force-pushed the bnb/relativistic_disc branch from 06c0990 to 3c9ada8 Compare April 16, 2025 17:51
Member

@grantbuster grantbuster left a comment


Minor suggestion but LGTM

Comment thread sup3r/models/base.py Outdated
if train_gen:
    loss = loss_gen
elif train_disc:
    loss = loss_disc
Member


It seems really inefficient to run all the loss calculations and then only output one of them. Why don't we wrap the actual loss calculations in the if statement?

Collaborator Author


Good call. We've had this inefficient setup for a while.

Member


Yeah, no blame - just reading our old code and thinking, well, that could have been done better haha

Collaborator Author


@grantbuster Actually, we need to compute the disc loss every batch to track whether to train it or not. We can skip the gen loss calcs though.
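The constraint described here can be sketched like this: the disc loss is computed every batch so the training loop can decide whether to update the discriminator, while the generator-side losses are only evaluated when the generator is actually training. This is an illustrative pattern, not the code from d122ace; `disc_loss_threshold`, the helper names, and the gating rule are all made up:

```python
import tensorflow as tf

def select_losses(calc_disc_loss, calc_gen_losses, train_gen,
                  disc_loss_threshold=0.6):
    """Compute the disc loss unconditionally (it gates disc training),
    but skip the more expensive generator losses when not training gen.

    ``calc_disc_loss`` and ``calc_gen_losses`` are zero-argument
    callables so each loss is only evaluated when actually needed.
    """
    loss_disc = calc_disc_loss()  # always needed for the gating decision
    # Illustrative rule: only update the disc when its loss is high enough
    train_disc = float(loss_disc) > disc_loss_threshold
    loss = loss_disc
    if train_gen:
        loss = calc_gen_losses()  # content + weighted adversarial terms
    return loss, loss_disc, train_disc
```

The key point is that `calc_disc_loss()` runs on every batch regardless of which network is training, while `calc_gen_losses()` stays behind the `train_gen` check.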

Member


Gotcha, makes sense. Maybe note this in an in-line comment so future-us remembers haha

Collaborator Author

@bnb32 bnb32 Apr 16, 2025


@grantbuster Nvm, how's this solution? d122ace

Member


LGTM

Comment thread sup3r/models/base.py Outdated
Comment on lines +868 to +890
loss_gen_content, loss_gen_content_details = (
    self.calc_loss_gen_content(hi_res_true, hi_res_gen)
)
loss_gen_advers = self.calc_loss_disc(
    disc_out_true=disc_out_gen, disc_out_gen=disc_out_true
)
loss_gen = loss_gen_content + weight_gen_advers * loss_gen_advers
loss_disc = self.calc_loss_disc(
    disc_out_true=disc_out_true, disc_out_gen=disc_out_gen
)

loss = None
if train_gen:
    loss = loss_gen
elif train_disc:
    loss = loss_disc

loss_details = {
    'loss_gen': loss_gen,
    'loss_gen_content': loss_gen_content,
    'loss_gen_advers': loss_gen_advers,
    'loss_disc': loss_disc,
}
Member


Suggested change

Remove:
loss_gen_content, loss_gen_content_details = (
    self.calc_loss_gen_content(hi_res_true, hi_res_gen)
)
loss_gen_advers = self.calc_loss_disc(
    disc_out_true=disc_out_gen, disc_out_gen=disc_out_true
)
loss_gen = loss_gen_content + weight_gen_advers * loss_gen_advers
loss_disc = self.calc_loss_disc(
    disc_out_true=disc_out_true, disc_out_gen=disc_out_gen
)
loss = None
if train_gen:
    loss = loss_gen
elif train_disc:
    loss = loss_disc
loss_details = {
    'loss_gen': loss_gen,
    'loss_gen_content': loss_gen_content,
    'loss_gen_advers': loss_gen_advers,
    'loss_disc': loss_disc,
}

Add:
loss_details = {}
if train_gen:
    loss_gen_content, loss_gen_content_details = (
        self.calc_loss_gen_content(hi_res_true, hi_res_gen)
    )
    loss_gen_advers = self.calc_loss_disc(
        disc_out_true=disc_out_gen, disc_out_gen=disc_out_true
    )
    loss = loss_gen_content + weight_gen_advers * loss_gen_advers
    loss_details['loss_gen'] = loss
    loss_details['loss_gen_content'] = loss_gen_content
    loss_details['loss_gen_advers'] = loss_gen_advers
elif train_disc:
    loss = self.calc_loss_disc(
        disc_out_true=disc_out_true, disc_out_gen=disc_out_gen
    )
    loss_details['loss_disc'] = loss

chunk going through the forward pass is too big. This is due to a
tensorflow padding bug, with the padding mode set to 'reflect'.
https://github.com/tensorflow/tensorflow/issues/91027
Member


You might note that in our current TF version 2.15 this actually just results in scrambled outputs.

Member


(which is very hard to detect)

@bnb32 bnb32 merged commit a386e51 into main Apr 17, 2025
12 checks passed
@bnb32 bnb32 deleted the bnb/relativistic_disc branch April 17, 2025 15:33
github-actions Bot pushed a commit that referenced this pull request Apr 17, 2025
Added relativistic discriminator loss used in ESRGAN paper.