|
| 1 | +from tools.ops import * |
| 2 | +from tools.utils import * |
| 3 | +from glob import glob |
| 4 | +import time |
| 5 | +import numpy as np |
| 6 | +from net import generator |
| 7 | +from net.discriminator import D_net |
| 8 | +from tools.data_loader import ImageGenerator |
| 9 | +from tools.vgg19 import Vgg19 |
| 10 | + |
| 11 | +class AnimeGANv2(object) : |
| 12 | + def __init__(self, sess, args): |
| 13 | + self.model_name = 'AnimeGANv2' |
| 14 | + self.sess = sess |
| 15 | + self.checkpoint_dir = args.checkpoint_dir |
| 16 | + self.log_dir = args.log_dir |
| 17 | + self.dataset_name = args.dataset |
| 18 | + |
| 19 | + self.epoch = args.epoch |
| 20 | + self.init_epoch = args.init_epoch # args.epoch // 20 |
| 21 | + |
| 22 | + self.gan_type = args.gan_type |
| 23 | + self.batch_size = args.batch_size |
| 24 | + self.save_freq = args.save_freq |
| 25 | + |
| 26 | + self.init_lr = args.init_lr |
| 27 | + self.d_lr = args.d_lr |
| 28 | + self.g_lr = args.g_lr |
| 29 | + |
| 30 | + """ Weight """ |
| 31 | + self.g_adv_weight = args.g_adv_weight |
| 32 | + self.d_adv_weight = args.d_adv_weight |
| 33 | + self.con_weight = args.con_weight |
| 34 | + self.sty_weight = args.sty_weight |
| 35 | + self.color_weight = args.color_weight |
| 36 | + self.tv_weight = args.tv_weight |
| 37 | + |
| 38 | + self.training_rate = args.training_rate |
| 39 | + self.ld = args.ld |
| 40 | + |
| 41 | + self.img_size = args.img_size |
| 42 | + self.img_ch = args.img_ch |
| 43 | + |
| 44 | + """ Discriminator """ |
| 45 | + self.n_dis = args.n_dis |
| 46 | + self.ch = args.ch |
| 47 | + self.sn = args.sn |
| 48 | + |
| 49 | + self.sample_dir = os.path.join(args.sample_dir, self.model_dir) |
| 50 | + check_folder(self.sample_dir) |
| 51 | + |
| 52 | + self.real = tf.placeholder(tf.float32, [self.batch_size, self.img_size[0], self.img_size[1], self.img_ch], name='real_A') |
| 53 | + self.anime = tf.placeholder(tf.float32, [self.batch_size, self.img_size[0], self.img_size[1], self.img_ch], name='anime_A') |
| 54 | + self.anime_smooth = tf.placeholder(tf.float32, [self.batch_size, self.img_size[0], self.img_size[1], self.img_ch], name='anime_smooth_A') |
| 55 | + self.test_real = tf.placeholder(tf.float32, [1, None, None, self.img_ch], name='test_input') |
| 56 | + |
| 57 | + self.anime_gray = tf.placeholder(tf.float32, [self.batch_size, self.img_size[0], self.img_size[1], self.img_ch],name='anime_B') |
| 58 | + |
| 59 | + |
| 60 | + self.real_image_generator = ImageGenerator('./dataset/train_photo', self.img_size, self.batch_size) |
| 61 | + self.anime_image_generator = ImageGenerator('./dataset/{}'.format(self.dataset_name + '/style'), self.img_size, self.batch_size) |
| 62 | + self.anime_smooth_generator = ImageGenerator('./dataset/{}'.format(self.dataset_name + '/smooth'), self.img_size, self.batch_size) |
| 63 | + self.dataset_num = max(self.real_image_generator.num_images, self.anime_image_generator.num_images) |
| 64 | + |
| 65 | + self.vgg = Vgg19() |
| 66 | + |
| 67 | + print() |
| 68 | + print("##### Information #####") |
| 69 | + print("# gan type : ", self.gan_type) |
| 70 | + print("# dataset : ", self.dataset_name) |
| 71 | + print("# max dataset number : ", self.dataset_num) |
| 72 | + print("# batch_size : ", self.batch_size) |
| 73 | + print("# epoch : ", self.epoch) |
| 74 | + print("# init_epoch : ", self.init_epoch) |
| 75 | + print("# training image size [H, W] : ", self.img_size) |
| 76 | + print("# g_adv_weight,d_adv_weight,con_weight,sty_weight,color_weight,tv_weight : ", self.g_adv_weight,self.d_adv_weight,self.con_weight,self.sty_weight,self.color_weight,self.tv_weight) |
| 77 | + print("# init_lr,g_lr,d_lr : ", self.init_lr,self.g_lr,self.d_lr) |
| 78 | + print(f"# training_rate G -- D: {self.training_rate} : 1" ) |
| 79 | + print() |
| 80 | + |
| 81 | + ################################################################################## |
| 82 | + # Generator |
| 83 | + ################################################################################## |
| 84 | + |
| 85 | + def generator(self, x_init, reuse=False, scope="generator"): |
| 86 | + with tf.variable_scope(scope, reuse=reuse): |
| 87 | + G = generator.G_net(x_init) |
| 88 | + return G.fake |
| 89 | + |
| 90 | + ################################################################################## |
| 91 | + # Discriminator |
| 92 | + ################################################################################## |
| 93 | + |
| 94 | + def discriminator(self, x_init, reuse=False, scope="discriminator"): |
| 95 | + D = D_net(x_init, self.ch, self.n_dis, self.sn, reuse=reuse, scope=scope) |
| 96 | + return D |
| 97 | + |
| 98 | + ################################################################################## |
| 99 | + # Model |
| 100 | + ################################################################################## |
| 101 | + def gradient_panalty(self, real, fake, scope="discriminator"): |
| 102 | + if self.gan_type.__contains__('dragan') : |
| 103 | + eps = tf.random_uniform(shape=tf.shape(real), minval=0., maxval=1.) |
| 104 | + _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3]) |
| 105 | + x_std = tf.sqrt(x_var) # magnitude of noise decides the size of local region |
| 106 | + |
| 107 | + fake = real + 0.5 * x_std * eps |
| 108 | + |
| 109 | + alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.) |
| 110 | + interpolated = real + alpha * (fake - real) |
| 111 | + |
| 112 | + logit, _= self.discriminator(interpolated, reuse=True, scope=scope) |
| 113 | + |
| 114 | + grad = tf.gradients(logit, interpolated)[0] # gradient of D(interpolated) |
| 115 | + grad_norm = tf.norm(flatten(grad), axis=1) # l2 norm |
| 116 | + |
| 117 | + GP = 0 |
| 118 | + # WGAN - LP |
| 119 | + if self.gan_type.__contains__('lp'): |
| 120 | + GP = self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.))) |
| 121 | + |
| 122 | + elif self.gan_type.__contains__('gp') or self.gan_type == 'dragan' : |
| 123 | + GP = self.ld * tf.reduce_mean(tf.square(grad_norm - 1.)) |
| 124 | + |
| 125 | + return GP |
| 126 | + |
| 127 | + def build_model(self): |
| 128 | + |
| 129 | + """ Define Generator, Discriminator """ |
| 130 | + self.generated = self.generator(self.real) |
| 131 | + self.test_generated = self.generator(self.test_real, reuse=True) |
| 132 | + |
| 133 | + |
| 134 | + anime_logit = self.discriminator(self.anime) |
| 135 | + anime_gray_logit = self.discriminator(self.anime_gray, reuse=True) |
| 136 | + |
| 137 | + generated_logit = self.discriminator(self.generated, reuse=True) |
| 138 | + smooth_logit = self.discriminator(self.anime_smooth, reuse=True) |
| 139 | + |
| 140 | + """ Define Loss """ |
| 141 | + if self.gan_type.__contains__('gp') or self.gan_type.__contains__('lp') or self.gan_type.__contains__('dragan') : |
| 142 | + GP = self.gradient_panalty(real=self.anime, fake=self.generated) |
| 143 | + else : |
| 144 | + GP = 0.0 |
| 145 | + |
| 146 | + # init pharse |
| 147 | + init_c_loss = con_loss(self.vgg, self.real, self.generated) |
| 148 | + init_loss = self.con_weight * init_c_loss |
| 149 | + |
| 150 | + self.init_loss = init_loss |
| 151 | + |
| 152 | + # gan |
| 153 | + c_loss, s_loss = con_sty_loss(self.vgg, self.real, self.anime_gray, self.generated) |
| 154 | + tv_loss = self.tv_weight * total_variation_loss(self.generated) |
| 155 | + t_loss = self.con_weight * c_loss + self.sty_weight * s_loss + color_loss(self.real,self.generated) * self.color_weight + tv_loss |
| 156 | + |
| 157 | + g_loss = self.g_adv_weight * generator_loss(self.gan_type, generated_logit) |
| 158 | + d_loss = self.d_adv_weight * discriminator_loss(self.gan_type, anime_logit, anime_gray_logit, generated_logit, smooth_logit) + GP |
| 159 | + |
| 160 | + self.Generator_loss = t_loss + g_loss |
| 161 | + self.Discriminator_loss = d_loss |
| 162 | + |
| 163 | + """ Training """ |
| 164 | + t_vars = tf.trainable_variables() |
| 165 | + G_vars = [var for var in t_vars if 'generator' in var.name] |
| 166 | + D_vars = [var for var in t_vars if 'discriminator' in var.name] |
| 167 | + |
| 168 | + self.init_optim = tf.train.AdamOptimizer(self.init_lr, beta1=0.5, beta2=0.999).minimize(self.init_loss, var_list=G_vars) |
| 169 | + self.G_optim = tf.train.AdamOptimizer(self.g_lr , beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars) |
| 170 | + self.D_optim = tf.train.AdamOptimizer(self.d_lr , beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars) |
| 171 | + |
| 172 | + """" Summary """ |
| 173 | + self.G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss) |
| 174 | + self.D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss) |
| 175 | + |
| 176 | + self.G_gan = tf.summary.scalar("G_gan", g_loss) |
| 177 | + self.G_vgg = tf.summary.scalar("G_vgg", t_loss) |
| 178 | + self.G_init_loss = tf.summary.scalar("G_init", init_loss) |
| 179 | + |
| 180 | + self.V_loss_merge = tf.summary.merge([self.G_init_loss]) |
| 181 | + self.G_loss_merge = tf.summary.merge([self.G_loss, self.G_gan, self.G_vgg, self.G_init_loss]) |
| 182 | + self.D_loss_merge = tf.summary.merge([self.D_loss]) |
| 183 | + |
| 184 | + def train(self): |
| 185 | + # initialize all variables |
| 186 | + self.sess.run(tf.global_variables_initializer()) |
| 187 | + |
| 188 | + # saver to save model |
| 189 | + self.saver = tf.train.Saver(max_to_keep=self.epoch) |
| 190 | + |
| 191 | + # summary writer |
| 192 | + self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph) |
| 193 | + |
| 194 | + """ Input Image""" |
| 195 | + real_img_op, anime_img_op, anime_smooth_op = self.real_image_generator.load_images(), self.anime_image_generator.load_images(), self.anime_smooth_generator.load_images() |
| 196 | + |
| 197 | + |
| 198 | + # restore check-point if it exits |
| 199 | + could_load, checkpoint_counter = self.load(self.checkpoint_dir) |
| 200 | + if could_load: |
| 201 | + start_epoch = checkpoint_counter + 1 |
| 202 | + |
| 203 | + print(" [*] Load SUCCESS") |
| 204 | + else: |
| 205 | + start_epoch = 0 |
| 206 | + |
| 207 | + print(" [!] Load failed...") |
| 208 | + |
| 209 | + # loop for epoch |
| 210 | + init_mean_loss = [] |
| 211 | + mean_loss = [] |
| 212 | + # training times , G : D = self.training_rate : 1 |
| 213 | + j = self.training_rate |
| 214 | + for epoch in range(start_epoch, self.epoch): |
| 215 | + for idx in range(int(self.dataset_num / self.batch_size)): |
| 216 | + anime, anime_smooth, real = self.sess.run([anime_img_op, anime_smooth_op, real_img_op]) |
| 217 | + train_feed_dict = { |
| 218 | + self.real:real[0], |
| 219 | + self.anime:anime[0], |
| 220 | + self.anime_gray:anime[1], |
| 221 | + self.anime_smooth:anime_smooth[1] |
| 222 | + } |
| 223 | + |
| 224 | + if epoch < self.init_epoch : |
| 225 | + # Init G |
| 226 | + start_time = time.time() |
| 227 | + |
| 228 | + real_images, generator_images, _, v_loss, summary_str = self.sess.run([self.real, self.generated, |
| 229 | + self.init_optim, |
| 230 | + self.init_loss, self.V_loss_merge], feed_dict = train_feed_dict) |
| 231 | + self.writer.add_summary(summary_str, epoch) |
| 232 | + init_mean_loss.append(v_loss) |
| 233 | + |
| 234 | + print("Epoch: %3d Step: %5d / %5d time: %f s init_v_loss: %.8f mean_v_loss: %.8f" % (epoch, idx,int(self.dataset_num / self.batch_size), time.time() - start_time, v_loss, np.mean(init_mean_loss))) |
| 235 | + if (idx+1)%200 ==0: |
| 236 | + init_mean_loss.clear() |
| 237 | + else : |
| 238 | + start_time = time.time() |
| 239 | + |
| 240 | + if j == self.training_rate: |
| 241 | + # Update D |
| 242 | + _, d_loss, summary_str = self.sess.run([self.D_optim, self.Discriminator_loss, self.D_loss_merge], |
| 243 | + feed_dict=train_feed_dict) |
| 244 | + self.writer.add_summary(summary_str, epoch) |
| 245 | + |
| 246 | + # Update G |
| 247 | + real_images, generator_images, _, g_loss, summary_str = self.sess.run([self.real, self.generated,self.G_optim, |
| 248 | + self.Generator_loss, self.G_loss_merge], feed_dict = train_feed_dict) |
| 249 | + self.writer.add_summary(summary_str, epoch) |
| 250 | + |
| 251 | + mean_loss.append([d_loss, g_loss]) |
| 252 | + if j == self.training_rate: |
| 253 | + |
| 254 | + print( |
| 255 | + "Epoch: %3d Step: %5d / %5d time: %f s d_loss: %.8f, g_loss: %.8f -- mean_d_loss: %.8f, mean_g_loss: %.8f" % ( |
| 256 | + epoch, idx, int(self.dataset_num / self.batch_size), time.time() - start_time, d_loss, g_loss, np.mean(mean_loss, axis=0)[0], |
| 257 | + np.mean(mean_loss, axis=0)[1])) |
| 258 | + else: |
| 259 | + print( |
| 260 | + "Epoch: %3d Step: %5d / %5d time: %f s , g_loss: %.8f -- mean_g_loss: %.8f" % ( |
| 261 | + epoch, idx, int(self.dataset_num / self.batch_size), time.time() - start_time, g_loss, np.mean(mean_loss, axis=0)[1])) |
| 262 | + |
| 263 | + if (idx + 1) % 200 == 0: |
| 264 | + mean_loss.clear() |
| 265 | + |
| 266 | + j = j - 1 |
| 267 | + if j < 1: |
| 268 | + j = self.training_rate |
| 269 | + |
| 270 | + |
| 271 | + if (epoch + 1) >= self.init_epoch and np.mod(epoch + 1, self.save_freq) == 0: |
| 272 | + self.save(self.checkpoint_dir, epoch) |
| 273 | + |
| 274 | + if epoch >= self.init_epoch -1: |
| 275 | + """ Result Image """ |
| 276 | + val_files = glob('./dataset/{}/*.*'.format('val')) |
| 277 | + save_path = './{}/{:03d}/'.format(self.sample_dir, epoch) |
| 278 | + check_folder(save_path) |
| 279 | + for i, sample_file in enumerate(val_files): |
| 280 | + print('val: '+ str(i) + sample_file) |
| 281 | + sample_image = np.asarray(load_test_data(sample_file, self.img_size)) |
| 282 | + test_real,test_generated = self.sess.run([self.test_real,self.test_generated],feed_dict = {self.test_real:sample_image} ) |
| 283 | + save_images(test_real, save_path+'{:03d}_a.jpg'.format(i), None) |
| 284 | + save_images(test_generated, save_path+'{:03d}_b.jpg'.format(i), None) |
| 285 | + |
| 286 | + @property |
| 287 | + def model_dir(self): |
| 288 | + return "{}_{}_{}_{}_{}_{}_{}_{}_{}".format(self.model_name, self.dataset_name, |
| 289 | + self.gan_type, |
| 290 | + int(self.g_adv_weight), int(self.d_adv_weight), |
| 291 | + int(self.con_weight), int(self.sty_weight), |
| 292 | + int(self.color_weight), int(self.tv_weight)) |
| 293 | + |
| 294 | + |
| 295 | + def save(self, checkpoint_dir, step): |
| 296 | + checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) |
| 297 | + if not os.path.exists(checkpoint_dir): |
| 298 | + os.makedirs(checkpoint_dir) |
| 299 | + self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step) |
| 300 | + |
| 301 | + def load(self, checkpoint_dir): |
| 302 | + print(" [*] Reading checkpoints...") |
| 303 | + checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) |
| 304 | + |
| 305 | + ckpt = tf.train.get_checkpoint_state(checkpoint_dir) # checkpoint file information |
| 306 | + |
| 307 | + if ckpt and ckpt.model_checkpoint_path: |
| 308 | + ckpt_name = os.path.basename(ckpt.model_checkpoint_path) # first line |
| 309 | + self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) |
| 310 | + counter = int(ckpt_name.split('-')[-1]) |
| 311 | + print(" [*] Success to read {}".format(os.path.join(checkpoint_dir, ckpt_name))) |
| 312 | + return True, counter |
| 313 | + else: |
| 314 | + print(" [*] Failed to find a checkpoint") |
| 315 | + return False, 0 |
0 commit comments