kernelGAN.py
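"""KernelGAN: blind estimation of an image-specific downscaling (SR) kernel.

A linear generator G learns to downscale the input image so that a patch
discriminator D cannot distinguish G's output patches from patches of the
original image; the kernel G implicitly applies is then extracted,
regularized, and handed to ZSSR for super-resolution (see "Blind
Super-Resolution Kernel Estimation using an Internal-GAN",
Bell-Kligler et al., NeurIPS 2019).
"""
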
import torch
import loss
import networks
import torch.nn.functional as F
from util import save_final_kernel, run_zssr, post_process_k


class KernelGAN:
    # Constraint coefficients
    lambda_sum2one = 0.5
    lambda_bicubic = 5
    lambda_boundaries = 0.5
    lambda_centralized = 0
    lambda_sparse = 0
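    # These weights scale the kernel regularizers summed in calc_constraints() below:
    # bicubic keeps G's downscaling close to plain bicubic, sum2one drives the kernel
    # entries to sum to 1, boundaries penalizes mass near the kernel's edges,
    # centralized pulls the kernel's center of mass toward its center, and sparse
    # encourages few significant entries. The last two start at 0 here; in the full
    # repo the Learner (learner.py) appears to raise them mid-training.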

    def __init__(self, conf):
        # Acquire configuration
        self.conf = conf
        # Define the GAN
        self.G = networks.Generator(conf).cuda()
        self.D = networks.Discriminator(conf).cuda()
        # Calculate D's input & output shape according to the shaving done by the networks
        self.d_input_shape = self.G.output_size
        self.d_output_shape = self.d_input_shape - self.D.forward_shave
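        # D's convolutions trim ("shave") the spatial map, so its per-pixel prediction
        # map is smaller than its input; GANLoss below builds its real/fake label maps
        # at this reduced size.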
        # Input tensors
        self.g_input = torch.FloatTensor(1, 3, conf.input_crop_size, conf.input_crop_size).cuda()
        self.d_input = torch.FloatTensor(1, 3, self.d_input_shape, self.d_input_shape).cuda()
        # The kernel G is imitating
        self.curr_k = torch.FloatTensor(conf.G_kernel_size, conf.G_kernel_size).cuda()
        # Losses
        self.GAN_loss_layer = loss.GANLoss(d_last_layer_size=self.d_output_shape).cuda()
        self.bicubic_loss = loss.DownScaleLoss(scale_factor=conf.scale_factor).cuda()
        self.sum2one_loss = loss.SumOfWeightsLoss().cuda()
        self.boundaries_loss = loss.BoundariesLoss(k_size=conf.G_kernel_size).cuda()
        self.centralized_loss = loss.CentralizedLoss(k_size=conf.G_kernel_size, scale_factor=conf.scale_factor).cuda()
        self.sparse_loss = loss.SparsityLoss().cuda()
        self.loss_bicubic = 0
        # Define loss function
        self.criterionGAN = self.GAN_loss_layer.forward
        # Initialize network weights
        self.G.apply(networks.weights_init_G)
        self.D.apply(networks.weights_init_D)
        # Optimizers
        self.optimizer_G = torch.optim.Adam(self.G.parameters(), lr=conf.g_lr, betas=(conf.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.D.parameters(), lr=conf.d_lr, betas=(conf.beta1, 0.999))
        print('*' * 60 + '\nSTARTED KernelGAN on: "%s"...' % conf.input_image_path)

    # noinspection PyUnboundLocalVariable
    def calc_curr_k(self):
        """Given a generator network, calculate the kernel it is imitating."""
        delta = torch.Tensor([1.]).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).cuda()
        for ind, w in enumerate(self.G.parameters()):
            curr_k = F.conv2d(delta, w, padding=self.conf.G_kernel_size - 1) if ind == 0 else F.conv2d(curr_k, w)
        self.curr_k = curr_k.squeeze().flip([0, 1])

    def train(self, g_input, d_input):
        self.set_input(g_input, d_input)
        self.train_g()
        self.train_d()

    def set_input(self, g_input, d_input):
        self.g_input = g_input.contiguous()
        self.d_input = d_input.contiguous()

    def train_g(self):
        # Zeroize gradients
        self.optimizer_G.zero_grad()
        # Generator forward pass
        g_pred = self.G.forward(self.g_input)
        # Pass the Generator's output through the Discriminator
        d_pred_fake = self.D.forward(g_pred)
        # Calculate generator loss, based on the discriminator's prediction on the generator result
        loss_g = self.criterionGAN(d_last_layer=d_pred_fake, is_d_input_real=True)
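        # Labelling the fake patch as "real" here is the standard non-saturating GAN
        # trick: G is updated to maximize D's belief that its output is real.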
        # Sum all losses
        total_loss_g = loss_g + self.calc_constraints(g_pred)
        # Calculate gradients
        total_loss_g.backward()
        # Update weights
        self.optimizer_G.step()

    def calc_constraints(self, g_pred):
        # Calculate K which is equivalent to G
        self.calc_curr_k()
        # Calculate constraints
        self.loss_bicubic = self.bicubic_loss.forward(g_input=self.g_input, g_output=g_pred)
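        # (kept on the instance rather than as a local so that code outside this class,
        # presumably the repo's Learner, can monitor it between iterations)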
        loss_boundaries = self.boundaries_loss.forward(kernel=self.curr_k)
        loss_sum2one = self.sum2one_loss.forward(kernel=self.curr_k)
        loss_centralized = self.centralized_loss.forward(kernel=self.curr_k)
        loss_sparse = self.sparse_loss.forward(kernel=self.curr_k)
        # Apply constraint coefficients
        return self.loss_bicubic * self.lambda_bicubic + loss_sum2one * self.lambda_sum2one + \
               loss_boundaries * self.lambda_boundaries + loss_centralized * self.lambda_centralized + \
               loss_sparse * self.lambda_sparse

    def train_d(self):
        # Zeroize gradients
        self.optimizer_D.zero_grad()
        # Discriminator forward pass over a real example
        d_pred_real = self.D.forward(self.d_input)
        # Discriminator forward pass over a fake example (produced by the generator)
        # Note that the generator result is detached so that gradients do not propagate back through the generator
        g_output = self.G.forward(self.g_input)
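        # A small Gaussian perturbation (std 1/255, roughly one gray level) is added to
        # the fake patch, a common instance-noise trick that stabilizes D training.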
        d_pred_fake = self.D.forward((g_output + torch.randn_like(g_output) / 255.).detach())
        # Calculate discriminator loss
        loss_d_fake = self.criterionGAN(d_pred_fake, is_d_input_real=False)
        loss_d_real = self.criterionGAN(d_pred_real, is_d_input_real=True)
        loss_d = (loss_d_fake + loss_d_real) * 0.5
        # Calculate gradients; they do not propagate back through the generator
        loss_d.backward()
        # Update weights; only discriminator weights are updated (by definition of the D optimizer)
        self.optimizer_D.step()

    def finish(self):
        final_kernel = post_process_k(self.curr_k, n=self.conf.n_filtering)
        save_final_kernel(final_kernel, self.conf)
        print('KernelGAN estimation complete!')
        run_zssr(final_kernel, self.conf)
        print('FINISHED RUN (see --%s-- folder)\n' % self.conf.output_dir_path + '*' * 60 + '\n\n')
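
# Typical driver loop (a minimal sketch, not part of this file; in the original repo
# this logic lives in train.py, and Config / DataGenerator / Learner are assumed to
# come from configs.py / data.py / learner.py):
#
#     import tqdm
#     from configs import Config
#     from data import DataGenerator
#     from learner import Learner
#
#     conf = Config().parse()
#     gan = KernelGAN(conf)
#     learner = Learner()
#     data = DataGenerator(conf, gan)
#     for iteration in tqdm.tqdm(range(conf.max_iters)):
#         g_in, d_in = data.__getitem__(iteration)
#         gan.train(g_in, d_in)
#         learner.update(iteration, gan)
#     gan.finish()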