Skip to content

Commit 7e3d099

Browse files
authored
Merge pull request #1558 from RB137/fix-issue-1545
Added AI Anime-Avatar Generator model with (README, requirements, LICENSE) files
2 parents a37ec8f + 3381c33 commit 7e3d099

File tree

9 files changed

+588
-0
lines changed

9 files changed

+588
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
from tools.ops import *
2+
from tools.utils import *
3+
from glob import glob
4+
import time
5+
import numpy as np
6+
from net import generator
7+
from net.discriminator import D_net
8+
from tools.data_loader import ImageGenerator
9+
from tools.vgg19 import Vgg19
10+
11+
class AnimeGANv2(object):
    """TensorFlow 1.x (graph-mode) trainer for the AnimeGANv2 model.

    The constructor copies hyper-parameters from ``args``, creates the input
    placeholders and the dataset loaders.  ``build_model`` wires the
    generator/discriminator graph, losses and optimizers; ``train`` runs the
    two-phase training loop (generator initialisation, then adversarial
    training) and periodically writes checkpoints and validation samples.

    Helper names such as ``tf``, ``os``, ``check_folder``, ``con_loss`` or
    ``save_images`` are expected to come from the module's wildcard imports
    (``tools.ops`` / ``tools.utils``).
    """

    def __init__(self, sess, args):
        """Store configuration, create placeholders and data loaders.

        Args:
            sess: TensorFlow session used for every ``run`` call.
            args: Parsed command-line namespace; every attribute read below
                must be present on it.
        """
        self.model_name = 'AnimeGANv2'
        self.sess = sess
        self.checkpoint_dir = args.checkpoint_dir
        self.log_dir = args.log_dir
        self.dataset_name = args.dataset

        self.epoch = args.epoch
        self.init_epoch = args.init_epoch  # args.epoch // 20

        self.gan_type = args.gan_type
        self.batch_size = args.batch_size
        self.save_freq = args.save_freq

        self.init_lr = args.init_lr
        self.d_lr = args.d_lr
        self.g_lr = args.g_lr

        # Loss weights
        self.g_adv_weight = args.g_adv_weight
        self.d_adv_weight = args.d_adv_weight
        self.con_weight = args.con_weight
        self.sty_weight = args.sty_weight
        self.color_weight = args.color_weight
        self.tv_weight = args.tv_weight

        self.training_rate = args.training_rate
        self.ld = args.ld

        self.img_size = args.img_size
        self.img_ch = args.img_ch

        # Discriminator settings
        self.n_dis = args.n_dis
        self.ch = args.ch
        self.sn = args.sn

        # model_dir is a property built from the attributes assigned above,
        # so it must only be read after they are all set.
        self.sample_dir = os.path.join(args.sample_dir, self.model_dir)
        check_folder(self.sample_dir)

        # All training inputs share the same fixed [N, H, W, C] shape.
        train_shape = [self.batch_size, self.img_size[0], self.img_size[1], self.img_ch]
        self.real = tf.placeholder(tf.float32, train_shape, name='real_A')
        self.anime = tf.placeholder(tf.float32, train_shape, name='anime_A')
        self.anime_smooth = tf.placeholder(tf.float32, train_shape, name='anime_smooth_A')
        # Test input accepts a single image of arbitrary spatial size.
        self.test_real = tf.placeholder(tf.float32, [1, None, None, self.img_ch], name='test_input')

        self.anime_gray = tf.placeholder(tf.float32, train_shape, name='anime_B')

        self.real_image_generator = ImageGenerator('./dataset/train_photo', self.img_size, self.batch_size)
        self.anime_image_generator = ImageGenerator('./dataset/{}'.format(self.dataset_name + '/style'), self.img_size, self.batch_size)
        self.anime_smooth_generator = ImageGenerator('./dataset/{}'.format(self.dataset_name + '/smooth'), self.img_size, self.batch_size)
        # An "epoch" is sized by the larger of the photo and style datasets.
        self.dataset_num = max(self.real_image_generator.num_images, self.anime_image_generator.num_images)

        self.vgg = Vgg19()

        print()
        print("##### Information #####")
        print("# gan type : ", self.gan_type)
        print("# dataset : ", self.dataset_name)
        print("# max dataset number : ", self.dataset_num)
        print("# batch_size : ", self.batch_size)
        print("# epoch : ", self.epoch)
        print("# init_epoch : ", self.init_epoch)
        print("# training image size [H, W] : ", self.img_size)
        print("# g_adv_weight,d_adv_weight,con_weight,sty_weight,color_weight,tv_weight : ", self.g_adv_weight,self.d_adv_weight,self.con_weight,self.sty_weight,self.color_weight,self.tv_weight)
        print("# init_lr,g_lr,d_lr : ", self.init_lr,self.g_lr,self.d_lr)
        print(f"# training_rate G -- D: {self.training_rate} : 1" )
        print()

    ##################################################################################
    # Generator
    ##################################################################################

    def generator(self, x_init, reuse=False, scope="generator"):
        """Build (or reuse) the generator under ``scope`` and return its output.

        Inside this method the bare name ``generator`` resolves to the
        module-level ``net.generator`` import, not to this method.
        """
        with tf.variable_scope(scope, reuse=reuse):
            G = generator.G_net(x_init)
            return G.fake

    ##################################################################################
    # Discriminator
    ##################################################################################

    def discriminator(self, x_init, reuse=False, scope="discriminator"):
        """Apply ``D_net`` to ``x_init`` and return its result unchanged."""
        D = D_net(x_init, self.ch, self.n_dis, self.sn, reuse=reuse, scope=scope)
        return D

    ##################################################################################
    # Model
    ##################################################################################

    def gradient_panalty(self, real, fake, scope="discriminator"):
        """Return the gradient-penalty term selected by ``self.gan_type``.

        Args:
            real: Batch of real samples.
            fake: Batch of generated samples; replaced by a noisy copy of
                ``real`` when the gan type contains ``'dragan'``.
            scope: Variable scope of the (already built) discriminator.

        Returns:
            A scalar tensor scaled by ``self.ld``, or ``0`` when the gan type
            matches neither penalty variant.
        """
        if 'dragan' in self.gan_type:
            eps = tf.random_uniform(shape=tf.shape(real), minval=0., maxval=1.)
            _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
            x_std = tf.sqrt(x_var)  # magnitude of noise decides the size of local region

            fake = real + 0.5 * x_std * eps

        alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)
        interpolated = real + alpha * (fake - real)

        # discriminator() returns a single value (build_model uses it directly
        # as a logit), so it must not be tuple-unpacked here.
        logit = self.discriminator(interpolated, reuse=True, scope=scope)

        grad = tf.gradients(logit, interpolated)[0]  # gradient of D(interpolated)
        grad_norm = tf.norm(flatten(grad), axis=1)  # l2 norm

        GP = 0
        if 'lp' in self.gan_type:
            # WGAN-LP: one-sided penalty, only norms above 1 are penalised.
            GP = self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.)))
        elif 'gp' in self.gan_type or self.gan_type == 'dragan':
            # Two-sided penalty pulling the norm towards 1.
            GP = self.ld * tf.reduce_mean(tf.square(grad_norm - 1.))

        return GP

    def build_model(self):
        """Construct the graph: networks, losses, optimizers and summaries."""
        # Generator: the training graph creates the variables, the test graph reuses them.
        self.generated = self.generator(self.real)
        self.test_generated = self.generator(self.test_real, reuse=True)

        # Discriminator: the first call creates the variables, the rest reuse them.
        anime_logit = self.discriminator(self.anime)
        anime_gray_logit = self.discriminator(self.anime_gray, reuse=True)

        generated_logit = self.discriminator(self.generated, reuse=True)
        smooth_logit = self.discriminator(self.anime_smooth, reuse=True)

        # Losses
        if 'gp' in self.gan_type or 'lp' in self.gan_type or 'dragan' in self.gan_type:
            GP = self.gradient_panalty(real=self.anime, fake=self.generated)
        else:
            GP = 0.0

        # Init phase: content loss only.
        init_c_loss = con_loss(self.vgg, self.real, self.generated)
        init_loss = self.con_weight * init_c_loss

        self.init_loss = init_loss

        # Adversarial phase.
        c_loss, s_loss = con_sty_loss(self.vgg, self.real, self.anime_gray, self.generated)
        tv_loss = self.tv_weight * total_variation_loss(self.generated)
        t_loss = self.con_weight * c_loss + self.sty_weight * s_loss + color_loss(self.real,self.generated) * self.color_weight + tv_loss

        g_loss = self.g_adv_weight * generator_loss(self.gan_type, generated_logit)
        d_loss = self.d_adv_weight * discriminator_loss(self.gan_type, anime_logit, anime_gray_logit, generated_logit, smooth_logit) + GP

        self.Generator_loss = t_loss + g_loss
        self.Discriminator_loss = d_loss

        # Optimizers: variables are split between G and D by scope name.
        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        self.init_optim = tf.train.AdamOptimizer(self.init_lr, beta1=0.5, beta2=0.999).minimize(self.init_loss, var_list=G_vars)
        self.G_optim = tf.train.AdamOptimizer(self.g_lr , beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars)
        self.D_optim = tf.train.AdamOptimizer(self.d_lr , beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars)

        # Summaries
        self.G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
        self.D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)

        self.G_gan = tf.summary.scalar("G_gan", g_loss)
        self.G_vgg = tf.summary.scalar("G_vgg", t_loss)
        self.G_init_loss = tf.summary.scalar("G_init", init_loss)

        self.V_loss_merge = tf.summary.merge([self.G_init_loss])
        self.G_loss_merge = tf.summary.merge([self.G_loss, self.G_gan, self.G_vgg, self.G_init_loss])
        self.D_loss_merge = tf.summary.merge([self.D_loss])

    def train(self):
        """Run the training loop.

        Epochs below ``self.init_epoch`` optimise only the generator's init
        loss.  Later epochs run adversarial training in which the generator
        is updated every step and the discriminator once every
        ``self.training_rate`` steps.  Checkpoints and validation samples are
        written at epoch boundaries.
        """
        # initialize all variables
        self.sess.run(tf.global_variables_initializer())

        # saver to save model
        self.saver = tf.train.Saver(max_to_keep=self.epoch)

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)

        # Input pipelines
        real_img_op = self.real_image_generator.load_images()
        anime_img_op = self.anime_image_generator.load_images()
        anime_smooth_op = self.anime_smooth_generator.load_images()

        # restore check-point if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch = checkpoint_counter + 1
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            print(" [!] Load failed...")

        init_mean_loss = []
        mean_loss = []
        steps_per_epoch = int(self.dataset_num / self.batch_size)
        # training times , G : D = self.training_rate : 1
        j = self.training_rate
        for epoch in range(start_epoch, self.epoch):
            for idx in range(steps_per_epoch):
                anime, anime_smooth, real = self.sess.run([anime_img_op, anime_smooth_op, real_img_op])
                # Each loader result is indexed: element 0 / element 1 feed
                # different placeholders (presumably colour vs. grayscale
                # variants -- confirm against tools.data_loader).
                train_feed_dict = {
                    self.real: real[0],
                    self.anime: anime[0],
                    self.anime_gray: anime[1],
                    self.anime_smooth: anime_smooth[1],
                }

                if epoch < self.init_epoch:
                    # Init G
                    start_time = time.time()

                    _, v_loss, summary_str = self.sess.run(
                        [self.init_optim, self.init_loss, self.V_loss_merge],
                        feed_dict=train_feed_dict)
                    self.writer.add_summary(summary_str, epoch)
                    init_mean_loss.append(v_loss)

                    print("Epoch: %3d Step: %5d / %5d time: %f s init_v_loss: %.8f mean_v_loss: %.8f" % (epoch, idx, steps_per_epoch, time.time() - start_time, v_loss, np.mean(init_mean_loss)))
                    # The running mean is a window of at most 200 steps.
                    if (idx + 1) % 200 == 0:
                        init_mean_loss.clear()
                else:
                    start_time = time.time()
                    update_d = j == self.training_rate

                    if update_d:
                        # Update D
                        _, d_loss, summary_str = self.sess.run(
                            [self.D_optim, self.Discriminator_loss, self.D_loss_merge],
                            feed_dict=train_feed_dict)
                        self.writer.add_summary(summary_str, epoch)

                    # Update G
                    _, g_loss, summary_str = self.sess.run(
                        [self.G_optim, self.Generator_loss, self.G_loss_merge],
                        feed_dict=train_feed_dict)
                    self.writer.add_summary(summary_str, epoch)

                    # On G-only steps d_loss still holds the value from the
                    # most recent discriminator update.
                    mean_loss.append([d_loss, g_loss])
                    mean_d_loss, mean_g_loss = np.mean(mean_loss, axis=0)
                    if update_d:
                        print(
                            "Epoch: %3d Step: %5d / %5d time: %f s d_loss: %.8f, g_loss: %.8f -- mean_d_loss: %.8f, mean_g_loss: %.8f" % (
                                epoch, idx, steps_per_epoch, time.time() - start_time, d_loss, g_loss, mean_d_loss,
                                mean_g_loss))
                    else:
                        print(
                            "Epoch: %3d Step: %5d / %5d time: %f s , g_loss: %.8f -- mean_g_loss: %.8f" % (
                                epoch, idx, steps_per_epoch, time.time() - start_time, g_loss, mean_g_loss))

                    if (idx + 1) % 200 == 0:
                        mean_loss.clear()

                    # Count down to the next discriminator update.
                    j = j - 1
                    if j < 1:
                        j = self.training_rate

            if (epoch + 1) >= self.init_epoch and np.mod(epoch + 1, self.save_freq) == 0:
                self.save(self.checkpoint_dir, epoch)

            if epoch >= self.init_epoch - 1:
                self._save_validation_samples(epoch)

    def _save_validation_samples(self, epoch):
        """Run the test generator on every validation file and save input/output pairs."""
        val_files = glob('./dataset/{}/*.*'.format('val'))
        save_path = './{}/{:03d}/'.format(self.sample_dir, epoch)
        check_folder(save_path)
        for i, sample_file in enumerate(val_files):
            print('val: '+ str(i) + sample_file)
            sample_image = np.asarray(load_test_data(sample_file, self.img_size))
            test_real, test_generated = self.sess.run(
                [self.test_real, self.test_generated],
                feed_dict={self.test_real: sample_image})
            save_images(test_real, save_path+'{:03d}_a.jpg'.format(i), None)
            save_images(test_generated, save_path+'{:03d}_b.jpg'.format(i), None)

    @property
    def model_dir(self):
        """Directory name encoding model, dataset, gan type and int-truncated loss weights."""
        return "{}_{}_{}_{}_{}_{}_{}_{}_{}".format(
            self.model_name, self.dataset_name,
            self.gan_type,
            int(self.g_adv_weight), int(self.d_adv_weight),
            int(self.con_weight), int(self.sty_weight),
            int(self.color_weight), int(self.tv_weight))

    def save(self, checkpoint_dir, step):
        """Write a checkpoint for ``step`` under ``checkpoint_dir/model_dir``.

        Uses ``self.saver``, which is created in ``train``.
        """
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
        os.makedirs(checkpoint_dir, exist_ok=True)
        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint under ``checkpoint_dir/model_dir``.

        Uses ``self.saver``, which is created in ``train``.

        Returns:
            ``(True, counter)`` where ``counter`` is the integer after the
            last ``-`` in the checkpoint file name, or ``(False, 0)`` when no
            checkpoint is found.
        """
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)  # checkpoint file information

        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)  # first line
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = int(ckpt_name.split('-')[-1])
            print(" [*] Success to read {}".format(os.path.join(checkpoint_dir, ckpt_name)))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2024 RAMESWAR BISOYI
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

0 commit comments

Comments
 (0)