@@ -46,6 +46,13 @@ class CycleGAN():
46
46
47
47
def input_setup (self ):
48
48
49
+ '''
50
+ This function basically setup variables for taking image input.
51
+
52
+ filenames_A/filenames_B -> takes the list of all training images
53
+ self.image_A/self.image_B -> Input image with each values ranging from [-1,1]
54
+ '''
55
+
49
56
filenames_A = tf .train .match_filenames_once ("./input/horse2zebra/trainA/*.jpg" )
50
57
self .queue_length_A = tf .size (filenames_A )
51
58
filenames_B = tf .train .match_filenames_once ("./input/horse2zebra/trainB/*.jpg" )
@@ -65,16 +72,21 @@ def input_setup(self):
65
72
66
73
def input_read (self , sess ):
67
74
75
+
76
+ '''
77
+ It reads the input into from the image folder.
78
+
79
+ self.fake_images_A/self.fake_images_B -> List of generated images used for calculation of loss function of Discriminator
80
+ self.A_input/self.B_input -> Stores all the training images in python list
81
+ '''
82
+
68
83
# Loading images into the tensors
69
84
coord = tf .train .Coordinator ()
70
85
threads = tf .train .start_queue_runners (coord = coord )
71
86
72
87
num_files_A = sess .run (self .queue_length_A )
73
88
num_files_B = sess .run (self .queue_length_B )
74
89
75
- images_A = []
76
- images_B = []
77
-
78
90
self .fake_images_A = np .zeros ((pool_size ,1 ,img_height , img_width , img_layer ))
79
91
self .fake_images_B = np .zeros ((pool_size ,1 ,img_height , img_width , img_layer ))
80
92
@@ -99,6 +111,14 @@ def input_read(self, sess):
99
111
100
112
def model_setup (self ):
101
113
114
+ ''' This function sets up the model to train
115
+
116
+ self.input_A/self.input_B -> Set of training images.
117
+ self.fake_A/self.fake_B -> Generated images by corresponding generator of input_A and input_B
118
+ self.lr -> Learning rate variable
119
+ self.cyc_A/ self.cyc_B -> Images generated after feeding self.fake_A/self.fake_B to corresponding generator. This is use to calcualte cyclic loss
120
+ '''
121
+
102
122
self .input_A = tf .placeholder (tf .float32 , [batch_size , img_width , img_height , img_layer ], name = "input_A" )
103
123
self .input_B = tf .placeholder (tf .float32 , [batch_size , img_width , img_height , img_layer ], name = "input_B" )
104
124
@@ -131,6 +151,13 @@ def model_setup(self):
131
151
132
152
def loss_calc (self ):
133
153
154
+ ''' In this function we are defining the variables for loss calcultions and traning model
155
+
156
+ d_loss_A/d_loss_B -> loss for discriminator A/B
157
+ g_loss_A/g_loss_B -> loss for generator A/B
158
+ *_trainer -> Variaous trainer for above loss functions
159
+ *_summ -> Summary variables for above loss functions'''
160
+
134
161
cyc_loss = tf .reduce_mean (tf .abs (self .input_A - self .cyc_A )) + tf .reduce_mean (tf .abs (self .input_B - self .cyc_B ))
135
162
136
163
disc_loss_A = tf .reduce_mean (tf .squared_difference (self .fake_rec_A ,1 ))
@@ -181,6 +208,9 @@ def save_training_images(self, sess, epoch):
181
208
imsave ("./output/imgs/inputB_" + str (epoch ) + "_" + str (i )+ ".jpg" ,((self .B_input [i ][0 ]+ 1 )* 127.5 ).astype (np .uint8 ))
182
209
183
210
def fake_image_pool (self , num_fakes , fake , fake_pool ):
211
+ ''' This function saves the generated image to corresponding pool of images.
212
+ In starting. It keeps on feeling the pool till it is full and then randomly selects an
213
+ already stored image and replace it with new one.'''
184
214
185
215
if (num_fakes < pool_size ):
186
216
fake_pool [num_fakes ] = fake
@@ -199,6 +229,9 @@ def fake_image_pool(self, num_fakes, fake, fake_pool):
199
229
def train (self ):
200
230
201
231
232
+ ''' Training Function '''
233
+
234
+
202
235
# Load Dataset from the dataset folder
203
236
self .input_setup ()
204
237
@@ -283,6 +316,9 @@ def train(self):
283
316
284
317
def test (self ):
285
318
319
+
320
+ ''' Testing Function'''
321
+
286
322
print ("Testing the results" )
287
323
288
324
self .input_setup ()
0 commit comments