adversarial.py
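"""Adversarial (GAN-style) training of a generator against a small fully-connected
discriminator, with metrics logged to Weights & Biases.

The datasets, generator model, and CLI live in utils: parse_cli dispatches either to
train() below for adversarial training or to the imported generate() for sampling
from a saved model.
"""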
import argparse
import os.path

import numpy as np
import torch.nn
import torch.utils.data
import torch.optim
import wandb
import tqdm

from utils import get_datasets_and_generator, parse_cli, generate

np.random.seed(0)
torch.manual_seed(0)

def train(args):
    # Define datasets and generator model.
    uniform_dataloader, normal_dataloader, generator = get_datasets_and_generator(args)
    generator = generator.to(device)
    # Define discriminator model (simple fully-connected net with LeakyReLUs).
    discriminator = torch.nn.Sequential(
        torch.nn.Linear(args.out_shape, 5),
        torch.nn.LeakyReLU(),
        torch.nn.Linear(5, 5),
        torch.nn.LeakyReLU(),
        torch.nn.Linear(5, 5),
        torch.nn.LeakyReLU(),
        torch.nn.Linear(5, 1),
        torch.nn.Sigmoid()
    ).to(device)

    # Weights & Biases (wandb) experiment tracking.
    wandb.init(project='gan_poc')
    wandb.config.update({'epochs': args.epochs,
                         'batch_size': args.batch_size,
                         'learning_rate': args.learning_rate,
                         'input_shape': args.in_shape,
                         'output_shape': args.out_shape,
                         })
    wandb.watch((generator, discriminator))

    # Define loss criterion (binary cross-entropy).
    criterion = torch.nn.BCELoss()
    # Define optimizers for generator and discriminator (Adam).
    optimizerG = torch.optim.Adam(generator.parameters(), lr=args.learning_rate)
    optimizerD = torch.optim.Adam(discriminator.parameters(), lr=2 * args.learning_rate)

    real_label = 1.0
    fake_label = 0.0

    for epoch in range(args.epochs):
        print('Epoch:', epoch)
        progress_bar = tqdm.tqdm(zip(uniform_dataloader, normal_dataloader),
                                 total=len(uniform_dataloader))
        for input_, target in progress_bar:
            input_, target = input_.float().to(device), target.float().to(device)
            # Format batch: one label per sample, initially "real".
            label = torch.full_like(target[:, 0, 0], real_label, device=device)

            ############################
            # (1) Update G network: maximize log(D(G(z)))
            ############################
            optimizerG.zero_grad()
            fake = generator(input_)  # Generate fake batch.
            # Perform a forward pass of the all-fake batch through D.
            output = discriminator(fake).view(-1)
            # Calculate G's loss based on this output (fake samples labelled as real).
            errG = criterion(output, label)
            # Calculate gradients for G.
            errG.backward()
            D_G_z1 = output.mean().item()
            # Update G.
            optimizerG.step()

            ############################
            # (2) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ############################
            optimizerD.zero_grad()
            # Forward pass real batch through D.
            output = discriminator(target).view(-1)
            # Calculate loss on the all-real batch.
            errD_real = criterion(output, label)
            # Calculate gradients for D in backward pass.
            errD_real.backward()
            D_x = output.mean().item()
            # Train with the fake batch.
            label.fill_(fake_label)
            # Forward pass fake batch through D; detach so no gradients flow back into G.
            output = discriminator(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch.
            errD_fake = criterion(output, label)
            # Calculate the gradients for this batch (accumulated onto the real-batch gradients).
            errD_fake.backward()
            D_G_z2 = output.mean().item()
            # Average the real and fake losses for reporting.
            errD = (errD_real + errD_fake) / 2
            # Update D.
            optimizerD.step()

            progress = {'Discriminator loss': errD.item(), 'Generator loss': errG.item()}
            progress_bar.set_postfix(progress)
            progress['D_x'] = D_x
            progress['D_G_z1'] = D_G_z1
            progress['D_G_z2'] = D_G_z2
            wandb.log(progress)

    print('Saving model.')
    torch.save(generator.state_dict(), args.model_path)
    torch.save(generator.state_dict(), os.path.join(wandb.run.dir, 'model.pt'))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train generator with adversarial training '
                                                 'or generate samples')
    args = parse_cli(parser, train_func=train, generate_func=generate)
    cuda = torch.cuda.is_available()
    device = 'cuda:0' if cuda else 'cpu'
    # Call the appropriate function (train or generate).
    args.func(args)
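
# Example invocations (hypothetical flag and subcommand names; the actual CLI is
# defined in utils.parse_cli, so check there before running):
#   python adversarial.py train --epochs 100 --batch-size 64 --learning-rate 1e-3
#   python adversarial.py generate --model-path model.pt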