Skip to content

Commit 1cea694

Browse files
committed
Added the doc string in functions
1 parent 4585c9d commit 1cea694

File tree

2 files changed

+53
-3
lines changed

2 files changed

+53
-3
lines changed

Diff for: download_datasets.sh

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
FILE=$1
2+
3+
if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "ae_photos" ]]; then
4+
echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
5+
exit 1
6+
fi
7+
8+
URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
9+
ZIP_FILE=./input/$FILE.zip
10+
TARGET_DIR=./input/$FILE/
11+
wget -N $URL -O $ZIP_FILE
12+
mkdir $TARGET_DIR
13+
unzip $ZIP_FILE -d ./input/
14+
rm $ZIP_FILE

Diff for: main.py

+39-3
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ class CycleGAN():
4646

4747
def input_setup(self):
4848

49+
'''
50+
This function basically setup variables for taking image input.
51+
52+
filenames_A/filenames_B -> takes the list of all training images
53+
self.image_A/self.image_B -> Input image with each values ranging from [-1,1]
54+
'''
55+
4956
filenames_A = tf.train.match_filenames_once("./input/horse2zebra/trainA/*.jpg")
5057
self.queue_length_A = tf.size(filenames_A)
5158
filenames_B = tf.train.match_filenames_once("./input/horse2zebra/trainB/*.jpg")
@@ -65,16 +72,21 @@ def input_setup(self):
6572

6673
def input_read(self, sess):
6774

75+
76+
'''
77+
It reads the input into from the image folder.
78+
79+
self.fake_images_A/self.fake_images_B -> List of generated images used for calculation of loss function of Discriminator
80+
self.A_input/self.B_input -> Stores all the training images in python list
81+
'''
82+
6883
# Loading images into the tensors
6984
coord = tf.train.Coordinator()
7085
threads = tf.train.start_queue_runners(coord=coord)
7186

7287
num_files_A = sess.run(self.queue_length_A)
7388
num_files_B = sess.run(self.queue_length_B)
7489

75-
images_A = []
76-
images_B = []
77-
7890
self.fake_images_A = np.zeros((pool_size,1,img_height, img_width, img_layer))
7991
self.fake_images_B = np.zeros((pool_size,1,img_height, img_width, img_layer))
8092

@@ -99,6 +111,14 @@ def input_read(self, sess):
99111

100112
def model_setup(self):
101113

114+
''' This function sets up the model to train
115+
116+
self.input_A/self.input_B -> Set of training images.
117+
self.fake_A/self.fake_B -> Generated images by corresponding generator of input_A and input_B
118+
self.lr -> Learning rate variable
119+
self.cyc_A/ self.cyc_B -> Images generated after feeding self.fake_A/self.fake_B to corresponding generator. This is use to calcualte cyclic loss
120+
'''
121+
102122
self.input_A = tf.placeholder(tf.float32, [batch_size, img_width, img_height, img_layer], name="input_A")
103123
self.input_B = tf.placeholder(tf.float32, [batch_size, img_width, img_height, img_layer], name="input_B")
104124

@@ -131,6 +151,13 @@ def model_setup(self):
131151

132152
def loss_calc(self):
133153

154+
''' In this function we are defining the variables for loss calcultions and traning model
155+
156+
d_loss_A/d_loss_B -> loss for discriminator A/B
157+
g_loss_A/g_loss_B -> loss for generator A/B
158+
*_trainer -> Variaous trainer for above loss functions
159+
*_summ -> Summary variables for above loss functions'''
160+
134161
cyc_loss = tf.reduce_mean(tf.abs(self.input_A-self.cyc_A)) + tf.reduce_mean(tf.abs(self.input_B-self.cyc_B))
135162

136163
disc_loss_A = tf.reduce_mean(tf.squared_difference(self.fake_rec_A,1))
@@ -181,6 +208,9 @@ def save_training_images(self, sess, epoch):
181208
imsave("./output/imgs/inputB_"+ str(epoch) + "_" + str(i)+".jpg",((self.B_input[i][0]+1)*127.5).astype(np.uint8))
182209

183210
def fake_image_pool(self, num_fakes, fake, fake_pool):
211+
''' This function saves the generated image to corresponding pool of images.
212+
In starting. It keeps on feeling the pool till it is full and then randomly selects an
213+
already stored image and replace it with new one.'''
184214

185215
if(num_fakes < pool_size):
186216
fake_pool[num_fakes] = fake
@@ -199,6 +229,9 @@ def fake_image_pool(self, num_fakes, fake, fake_pool):
199229
def train(self):
200230

201231

232+
''' Training Function '''
233+
234+
202235
# Load Dataset from the dataset folder
203236
self.input_setup()
204237

@@ -283,6 +316,9 @@ def train(self):
283316

284317
def test(self):
285318

319+
320+
''' Testing Function'''
321+
286322
print("Testing the results")
287323

288324
self.input_setup()

0 commit comments

Comments
 (0)