-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
65 lines (44 loc) · 1.69 KB
/
main.py
File metadata and controls
65 lines (44 loc) · 1.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from src.parsing import Options
from src.train import train_clf, train_sr, train_sr_gan, train_imp, train_imp_gan
from src.evaluate import evaluate_clf, evaluate_ecg_sr, evaluate_ecg_imp, evaluate_shl_imp, evaluate_shl_sr,\
evaluate_audio_imp,evaluate_audio_sr, evaluate_pam2_imp, evaluate_pam2_sr
import os
if __name__ == "__main__":
opt = Options().parse()
if opt.decay_half:
opt.decay_every = opt.epochs//2
else:
opt.decay_every = opt.epochs+1 #This will ensure that the learning rate does not change
if opt.save_dir:
os.makedirs(opt.save_dir,exist_ok=True)
if opt.evaluate:
if opt.model_type == 'clf':
evaluate_clf(opt)
elif opt.model_type in ['sr','sr_gan']:
if opt.data_type=='ecg':
evaluate_ecg_sr(opt)
elif opt.data_type=='shl':
evaluate_shl_sr(opt)
elif opt.data_type=='audio':
evaluate_audio_sr(opt)
elif opt.data_type=='pam2':
evaluate_pam2_sr(opt)
elif opt.model_type in ['imp','imp_gan']:
if opt.data_type=='ecg':
evaluate_ecg_imp(opt)
elif opt.data_type=='shl':
evaluate_shl_imp(opt)
elif opt.data_type=='audio':
evaluate_audio_imp(opt)
elif opt.model_type=='pam2':
evaluate_pam2_imp(opt)
elif opt.model_type == 'clf':
train_clf(opt)
elif opt.model_type == 'sr':
train_sr(opt)
elif opt.model_type == 'sr_gan':
train_sr_gan(opt)
elif opt.model_type == 'imp':
train_imp(opt)
elif opt.model_type == 'imp_gan':
train_imp_gan(opt)