Skip to content

Commit 5b905ae

Browse files
committed
hot fix: wasn't updating b_loss_details correctly when both gen and disc are trained.
1 parent a386e51 commit 5b905ae

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

sup3r/models/base.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -982,7 +982,7 @@ def _train_batch(
982982

983983
trained_gen = False
984984
trained_disc = False
985-
b_loss_details = {}
985+
loss_details = {}
986986
if only_gen or (train_gen and not gen_too_good):
987987
trained_gen = True
988988
b_loss_details = self.timer(self.run_gradient_descent)(
@@ -996,6 +996,7 @@ def _train_batch(
996996
compute_disc=train_disc,
997997
multi_gpu=multi_gpu,
998998
)
999+
loss_details.update(b_loss_details)
9991000

10001001
if only_disc or (train_disc and not disc_too_good):
10011002
trained_disc = True
@@ -1009,10 +1010,11 @@ def _train_batch(
10091010
train_disc=True,
10101011
multi_gpu=multi_gpu,
10111012
)
1013+
loss_details.update(b_loss_details)
10121014

1013-
b_loss_details['gen_train_frac'] = float(trained_gen)
1014-
b_loss_details['disc_train_frac'] = float(trained_disc)
1015-
return b_loss_details
1015+
loss_details['gen_train_frac'] = float(trained_gen)
1016+
loss_details['disc_train_frac'] = float(trained_disc)
1017+
return loss_details
10161018

10171019
def _post_batch(
10181020
self, ib, b_loss_details, loss_mean_window, n_batches, previous_means

0 commit comments

Comments
 (0)