-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
62 lines (50 loc) · 1.7 KB
/
main.py
File metadata and controls
62 lines (50 loc) · 1.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# #%%
from torch import uint8
from train import *
import matplotlib.pyplot as plt
import time
from dataset import *
if param.mode == "Train":
    # Output locations for trained weights, sample images and noise maps.
    path_weight, path_img, path_noise = get_path()
    Gs, Ds = init_models(param.mode)
    optimizerGs, optimizerDs = init_optimizer(Gs, Ds)
    # Train the scales one after another (indices 0 .. param.num_scale - 1).
    for scale_num in range(param.num_scale):
        if scale_num > 0:
            # Initialise this scale's generator/discriminator from the weights
            # saved for the previous scale before training it.
            prev_dir = path_weight + "/{}_scale".format(scale_num - 1)
            Gs[scale_num].load_state_dict(torch.load(prev_dir + "/netG.pth"))
            Ds[scale_num].load_state_dict(torch.load(prev_dir + "/netD.pth"))
        Gs, Ds = train_single(
            Gs,
            Ds,
            scale_num,
            optimizerGs,
            optimizerDs,
            path_weight,
            path_img,
            path_noise,
        )
elif param.mode == "Test":
    start = time.time()  # wall-clock start time for the total-runtime report
    # The fixed noise and the generator weights do not depend on the test
    # batch, so load them once up front instead of once per batch.
    # NOTE(review): assumes `test` does not modify `fixed_noise` or the
    # generators in place; if it does, move this loading back into the loop.
    fixed_noise = torch.load("./noise/" + param.folder_des + "/fixed_noise.pth")
    fixed_noise = fixed_noise.unsqueeze(0)  # prepend a dimension of size 1
    Gs, Ds = init_models(param.mode)  # Ds is not used in test mode below
    for scale_num in range(param.num_scale):
        Gs[scale_num].load_state_dict(
            torch.load(
                "./weights/"
                + param.folder_des
                + "/{}_scale/netG.pth".format(scale_num)
            )
        )
        Gs[scale_num].eval()
    for imgs in test_loader:
        # Run `param.test_num` test passes for each batch, passing the pass index.
        for i in range(param.test_num):
            test(Gs, imgs, param.num_scale, fixed_noise, i)
    print("time :", time.time() - start)  # elapsed time = now - start time