diff --git a/ML/Pytorch/GANs/2. DCGAN/train.py b/ML/Pytorch/GANs/2. DCGAN/train.py index aa943682..3980aa79 100644 --- a/ML/Pytorch/GANs/2. DCGAN/train.py +++ b/ML/Pytorch/GANs/2. DCGAN/train.py @@ -77,7 +77,7 @@ loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake)) loss_disc = (loss_disc_real + loss_disc_fake) / 2 disc.zero_grad() - loss_disc.backward() + loss_disc.backward(retain_graph=True) opt_disc.step() ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))