-
Notifications
You must be signed in to change notification settings - Fork 232
/
generate.py
100 lines (73 loc) · 2.89 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import argparse
import math
import torch
from torchvision import utils
from model import StyledGenerator
@torch.no_grad()
def get_mean_style(generator, device, n_iter=10, batch_size=1024, code_dim=512):
    """Estimate the generator's average style by averaging several batch means.

    Each of ``n_iter`` rounds draws ``batch_size`` standard-normal latent codes of
    width ``code_dim``, moves them to ``device`` and passes them through
    ``generator.mean_style``; the results are summed and divided by ``n_iter``.

    Args:
        generator: object exposing ``mean_style(codes)``; presumably a
            ``StyledGenerator`` (TODO confirm the returned shape, likely
            ``(1, style_dim)``).
        device: device the latent codes are moved to.
        n_iter: number of batches to average (default 10, the original value).
        batch_size: latent codes per batch (default 1024).
        code_dim: latent code width (default 512).

    Returns:
        The mean of the ``n_iter`` tensors returned by ``generator.mean_style``.

    Raises:
        ValueError: if ``n_iter`` is less than 1.
    """
    if n_iter < 1:
        raise ValueError(f'n_iter must be >= 1, got {n_iter}')

    total = None
    for _ in range(n_iter):
        # Sample on the CPU generator and move, as before, so a fixed seed
        # reproduces the same codes regardless of device.
        codes = torch.randn(batch_size, code_dim).to(device)
        style = generator.mean_style(codes)
        # Accumulate out-of-place: never mutate a tensor the generator handed
        # back, in case it aliases internal state.
        total = style if total is None else total + style

    return total / n_iter
@torch.no_grad()
def sample(generator, step, mean_style, n_sample, device):
    """Render ``n_sample`` images from fresh random 512-d latent codes.

    The generator is called at progressive-growing ``step`` with ``alpha=1``,
    truncated towards ``mean_style`` with ``style_weight=0.7``. Gradients are
    disabled via ``torch.no_grad``.

    Returns:
        Whatever ``generator`` returns for the batch (presumably an image tensor
        of shape ``(n_sample, 3, H, W)`` — confirm against ``StyledGenerator``).
    """
    latents = torch.randn(n_sample, 512).to(device)
    render_kwargs = {
        'step': step,
        'alpha': 1,
        'mean_style': mean_style,
        'style_weight': 0.7,
    }
    return generator(latents, **render_kwargs)
@torch.no_grad()
def style_mixing(generator, step, mean_style, n_source, n_target, device):
    """Build a style-mixing grid as one concatenated batch of images.

    Layout (read row-major with ``n_source + 1`` images per row):
      * row 0: a constant -1 filler tile, then the ``n_source`` source images;
      * each following row: one target image, then ``n_source`` images rendered
        from the pair ``[that target's code repeated, all source codes]`` with
        ``mixing_range=(0, 1)``. Presumably the first code drives layers 0-1 and
        the second the rest — confirm against ``StyledGenerator``.

    Returns:
        ``torch.cat`` of all tiles along dim 0.
    """
    # Draw source codes before target codes so a fixed seed gives the same grid.
    source_code = torch.randn(n_source, 512).to(device)
    target_code = torch.randn(n_target, 512).to(device)

    side = 4 * 2 ** step
    shared = dict(step=step, alpha=1, mean_style=mean_style, style_weight=0.7)

    source_image = generator(source_code, **shared)
    target_image = generator(target_code, **shared)

    filler = -torch.ones(1, 3, side, side).to(device)
    tiles = [filler, source_image]

    for row, target in enumerate(target_code):
        mixed = generator(
            [target.unsqueeze(0).repeat(n_source, 1), source_code],
            mixing_range=(0, 1),
            **shared,
        )
        tiles.extend([target_image[row].unsqueeze(0), mixed])

    return torch.cat(tiles, 0)
if __name__ == '__main__':
    import inspect

    parser = argparse.ArgumentParser()
    parser.add_argument('--size', type=int, default=1024, help='size of the image')
    parser.add_argument('--n_row', type=int, default=3, help='number of rows of sample matrix')
    parser.add_argument('--n_col', type=int, default=5, help='number of columns of sample matrix')
    parser.add_argument(
        '--n_mixing', type=int, default=20, help='number of style-mixing grids to save'
    )
    parser.add_argument('path', type=str, help='path to checkpoint file')
    args = parser.parse_args()

    # Images are 4 * 2**step pixels per side, so only powers of two >= 4 map to
    # a step; reject anything else instead of silently flooring it.
    if args.size < 4 or args.size & (args.size - 1):
        parser.error(f'--size must be a power of two >= 4, got {args.size}')
    # math.log2 is exact for powers of two (math.log(x, 2) is not guaranteed).
    step = int(math.log2(args.size)) - 2

    # Fall back to CPU when CUDA is unavailable instead of crashing.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    generator = StyledGenerator(512).to(device)
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted
    # sources. map_location lets a GPU-saved checkpoint load on a CPU host.
    checkpoint = torch.load(args.path, map_location=device)
    generator.load_state_dict(checkpoint['g_running'])
    generator.eval()

    # Newer torchvision renamed make_grid/save_image's ``range`` kwarg to
    # ``value_range`` and later dropped ``range``; use whichever is supported.
    grid_params = inspect.signature(utils.make_grid).parameters
    range_kw = 'value_range' if 'value_range' in grid_params else 'range'
    grid_kwargs = {'normalize': True, range_kw: (-1, 1)}

    mean_style = get_mean_style(generator, device)

    img = sample(generator, step, mean_style, args.n_row * args.n_col, device)
    utils.save_image(img, 'sample.png', nrow=args.n_col, **grid_kwargs)

    for j in range(args.n_mixing):
        img = style_mixing(generator, step, mean_style, args.n_col, args.n_row, device)
        # Each mixing grid row holds one target tile plus n_col mixed tiles.
        utils.save_image(
            img, f'sample_mixing_{j}.png', nrow=args.n_col + 1, **grid_kwargs
        )