-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
128 lines (113 loc) · 3.01 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import argparse
import numpy as np
import torch
from load import load_solar_data, load_wind_data
from model import GAN
def _select_device():
    """Return the name of the torch device to train on.

    Preference order: 'cuda' if available, then 'mps', otherwise 'cpu'.
    """
    if torch.cuda.is_available():
        return 'cuda'
    if torch.backends.mps.is_available():
        return 'mps'
    return 'cpu'


def _load_dataset(data_path, label_path):
    """Load ``(data, labels, m)`` for the data set named by ``data_path``.

    The loader is chosen from the file-name suffix: paths ending in
    'solar.csv' use ``load_solar_data``, and paths ending in 'wind.csv' use
    ``load_wind_data``. Both are called as ``loader(data_path, label_path)``
    and their 3-tuple result is returned unchanged.

    Raises:
        ValueError: if ``data_path`` is None/empty or matches neither suffix.
    """
    if not data_path:
        raise ValueError(
            "No data set given; pass --data with a path ending in "
            "'solar.csv' or 'wind.csv'")
    if data_path.endswith('solar.csv'):
        return load_solar_data(data_path, label_path)
    if data_path.endswith('wind.csv'):
        return load_wind_data(data_path, label_path)
    # Fail fast: an unrecognised file would otherwise leave data, labels and
    # the scale as None and only fail later with an unrelated error.
    raise ValueError(
        f"Unsupported data set {data_path!r}; expected a path ending in "
        "'solar.csv' or 'wind.csv'")


def main(args):
    """Train the GAN on the selected data set and return generated samples.

    Args:
        args: parsed command-line namespace. The attributes read are
            ``data``, ``label``, ``epochs``, ``batch_size`` and
            ``learning_rate``.

    Returns:
        The samples from ``GAN.predict()``, multiplied by the scale value
        ``m`` returned by the data loader.

    Raises:
        ValueError: if ``args.data`` is missing or is not a supported
            ('*solar.csv' / '*wind.csv') data set.
    """
    device = _select_device()
    print('Current device:', device)

    # m is the loader's third output. NOTE(review): presumably the
    # normalisation factor applied to the data -- confirm in load.py.
    trX, trY, m = _load_dataset(args.data, args.label)

    # dim_y is set to the number of distinct label values.
    events_num = len(np.unique(trY))

    GAN_model = GAN(epochs=args.epochs,
                    batch_size=args.batch_size,
                    learning_rate=args.learning_rate,
                    dim_y=events_num,
                    device=device).to(device)
    GAN_model.fit(trX, trY)

    # predict() also returns the sampled labels; this script does not use them.
    data_gen, _ = GAN_model.predict()

    # Rescale by m (presumably undoing the loader's normalisation).
    return data_gen * m
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--data',
help='Select data set for training',
type=str
)
parser.add_argument(
'--label',
help='Select labels corresponding to data set',
type=str
)
parser.add_argument(
'--epochs',
help='Training iterations',
default=5000,
type=int
)
parser.add_argument(
'--batch_size',
help='Number of samples for one optimization',
default=32,
type=int
)
parser.add_argument(
'--learning_rate',
help='Learning rate for optimizer',
default=1e-4,
type=float
)
#Exclude the following structural parameters
# parser.add_argument(
# '--image_shape',
# help='Define image shape (channels, height, width)',
# default=[1, 24, 24],
# tpye=list
# )
# parser.add_argument(
# '--dim_y',
# help='Sets number of channels (corresponds to the number of unique labels)',
# default=6,
# type=int
# )
# parser.add_argument(
# '--dim_z',
# help='Sets number of channels for sampled noise images',
# default=100,
# type=int
# )
# parser.add_argument(
# '--dim_W1',
# help='Layer dimension parameter',
# default=1024,
# type=int
# )
# parser.add_argument(
# '--dim_W2',
# help='Layer dimension parameter',
# default=128,
# type=int
# )
# parser.add_argument(
# '--dim_W3',
# help='Layer dimension parameter',
# default=64,
# type=int
# )
# parser.add_argument(
# '--dim_channel',
# help='Output dimension channels',
# default=1,
# type=int
# )
args = parser.parse_args()
data_gen = main(args)