-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathmain.py
220 lines (176 loc) · 6.8 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
#!/usr/bin/env python
# @Time : 2021/3/8 15:06
# @Author : wb
# @File : main.py
'''
主文件,用于训练,测试等
'''
import torch
from torch.utils.data import DataLoader
from torchnet import meter
from tqdm import tqdm
import os
import models
from config import opt
from utils.visualize import Visualizer
from data.dataset import CWRUDataset1D, CWRUDataset2D
def train(**kwargs):
    """Train the model named by ``opt.model`` and validate it after every epoch.

    :param kwargs: overrides for the defaults in ``config.opt``; applied via
        ``opt.parse(kwargs)`` before anything else runs.
    :return: None. Side effects: ``model.save()`` is called once per epoch,
        and loss / validation accuracy / confusion matrices are sent to
        visdom (a visdom server must be running).
    """
    # Apply command-line overrides to the global config.
    opt.parse(kwargs)
    # visdom plotting helper; needs a running visdom server.
    vis = Visualizer(opt.env, port=opt.vis_port)

    # step 1: model -- instantiate the class named by opt.model.
    model = getattr(models, opt.model)()
    if opt.load_model_path:
        # ``load`` may restore weights in place (returning None) or return the
        # loaded model; only replace our instance when a model comes back.
        loaded = model.load(opt.load_model_path)
        if loaded is not None:
            model = loaded
    if torch.cuda.is_available():
        # NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if set before
        # CUDA is initialised; setting it here may be too late -- confirm.
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    model = model.to(opt.device)

    # step 2: data. The held-out (train=False) split doubles as the
    # validation/test set; none of it is used for training.
    train_data = CWRUDataset1D(opt.train_data_root, train=True)
    test_data = CWRUDataset1D(opt.train_data_root, train=False)
    train_dataloader = DataLoader(train_data, opt.batch_size, shuffle=True)
    test_dataloader = DataLoader(test_data, opt.batch_size, shuffle=False)

    # step 3: objective, optimizer (Adam) and step-decay schedule.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, opt.lr_decay_iters, opt.lr_decay)

    # step 4: metrics -- running mean of the loss and a confusion matrix.
    loss_meter = meter.AverageValueMeter()
    confusion_matrix = meter.ConfusionMeter(opt.CWRU_category)
    previous_loss = 1e10

    for epoch in range(opt.max_epoch):
        loss_meter.reset()
        confusion_matrix.reset()
        for ii, (data, label) in enumerate(tqdm(train_dataloader)):
            # Add a channel axis: (batch, length) -> (batch, 1, length).
            # Assumes data is 2-D, as the former resize_(size[0], 1, size[1]) did.
            inputs = data.unsqueeze(1).float().to(opt.device)
            target = label.long().to(opt.device)

            optimizer.zero_grad()
            score = model(inputs)
            loss = criterion(score, target)
            loss.backward()
            optimizer.step()
            # Stepped once per batch, so opt.lr_decay_iters counts iterations.
            scheduler.step()

            loss_meter.add(loss.item())
            # detach so the meter never holds on to the autograd graph.
            confusion_matrix.add(score.detach(), target.detach())

            if (ii + 1) % opt.print_freq == 0:
                vis.plot('loss', loss_meter.value()[0])
                print('t = %d, loss = %.4f' % (ii + 1, loss.item()))
                # Drop into the debugger when the debug marker file exists.
                if os.path.exists(opt.debug_file):
                    import ipdb
                    ipdb.set_trace()

        # Checkpoint once per epoch.
        model.save()

        # Validation metrics and logging.
        val_cm, val_accuracy = val(model, test_dataloader)
        vis.plot('val_accuracy', val_accuracy)
        current_lr = optimizer.param_groups[0]['lr']
        vis.log("epoch:{epoch},lr:{lr},loss:{loss},train_cm:{train_cm},val_cm:{val_cm}".format(
            epoch=epoch, loss=loss_meter.value()[0], val_cm=str(val_cm.value()),
            train_cm=str(confusion_matrix.value()), lr=current_lr))

        epoch_loss = loss_meter.value()[0]
        if epoch_loss > previous_loss:
            # Loss went up: decay the *current* rate in place. Multiplying the
            # param-group value (instead of resetting it from a separate copy
            # of opt.lr) composes with the StepLR decay above, and editing
            # param_groups keeps Adam's moment estimates.
            for param_group in optimizer.param_groups:
                param_group['lr'] *= opt.lr_decay
        previous_loss = epoch_loss
def val(model, dataloader):
    """Evaluate ``model`` on ``dataloader``; return its confusion matrix and accuracy.

    The model is put in eval mode for the pass (BatchNorm uses its running
    statistics, Dropout is disabled). Afterwards it is restored to whatever
    mode it was in before the call, even if evaluation raises.

    :param model: network producing one score per class for each sample.
    :param dataloader: yields ``(data, label)`` batches. ``data`` is presumably
        ``(batch, length)``; a channel axis is added before the forward pass.
    :return: ``(confusion_matrix, accuracy)``. ``confusion_matrix`` is the
        filled ``meter.ConfusionMeter``. ``accuracy`` is the percentage of
        correctly classified samples across all classes, or 0.0 when the
        loader yields nothing.
    """
    was_training = model.training
    model.eval()
    confusion_matrix = meter.ConfusionMeter(opt.CWRU_category)
    try:
        # Evaluation needs no gradients; skipping autograd saves memory/time.
        with torch.no_grad():
            for data, label in tqdm(dataloader):
                # (batch, length) -> (batch, 1, length), matching train().
                test_input = data.unsqueeze(1).float().to(opt.device)
                target = label.long().to(opt.device)
                score = model(test_input)
                confusion_matrix.add(score, target)
    finally:
        model.train(was_training)

    cm_value = confusion_matrix.value()
    total = cm_value.sum()
    # Correct predictions lie on the diagonal; sum all of it so every one of
    # the opt.CWRU_category classes counts, not just the first two.
    accuracy = 100. * cm_value.trace() / total if total else 0.0
    return confusion_matrix, accuracy
def test(**kwargs):
    """Run inference on the test split and write the results to ``opt.result_file``.

    :param kwargs: overrides for the defaults in ``config.opt``, applied the
        same way ``train`` applies them.
    :return: list of ``(id, probability)`` tuples, also written as CSV via
        ``write_csv``. ``probability`` is the softmax probability of class 0
        only.
    """
    opt.parse(kwargs)

    model = getattr(models, opt.model)()
    if opt.load_model_path:
        # ``load`` may restore weights in place (returning None) or return the
        # loaded model; only replace our instance when a model comes back.
        loaded = model.load(opt.load_model_path)
        if loaded is not None:
            model = loaded
    # eval() after loading so the model actually used is in inference mode.
    model = model.to(opt.device).eval()

    # NOTE(review): train()/val() use CWRUDataset1D and add a channel axis,
    # while this uses CWRUDataset2D with no reshape -- confirm the selected
    # model accepts this input and that ``test=True`` is a valid argument.
    test_data = CWRUDataset2D(opt.train_data_root, test=True)
    test_dataloader = DataLoader(test_data, batch_size=opt.batch_size,
                                 shuffle=False, num_workers=opt.num_workers)

    results = []
    with torch.no_grad():
        for data, path in tqdm(test_dataloader):
            score = model(data.to(opt.device))
            # NOTE(review): only the class-0 probability is kept, which looks
            # inherited from a binary-classification template; confirm intent.
            probability = torch.nn.functional.softmax(score, dim=1)[:, 0].tolist()
            results.extend((path_.item(), probability_)
                           for path_, probability_ in zip(path, probability))

    write_csv(results, opt.result_file)
    return results
def write_csv(results, file_name):
    """Write ``results`` to ``file_name`` as CSV under an ``id,label`` header.

    :param results: iterable of rows, e.g. ``(id, probability)`` tuples.
    :param file_name: destination path; an existing file is overwritten.
    """
    import csv
    # newline='' lets the csv writer control line endings itself; without it
    # every row gets an extra blank line on Windows ("\r\r\n").
    with open(file_name, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(['id', 'label'])
        writer.writerows(results)
def build_pseudo_label():
    """Placeholder for building a set of pseudo labels.

    Not implemented yet: it performs no work and returns ``None``.
    """
    return None
def help():
    """Print command-line usage, then the source of the config class.

    The config class source is printed after the "available args:" header
    because it lists every tunable option with its default value.
    Invoked as ``python file.py help``. (The name shadows the ``help``
    builtin; it is kept for the command-line interface.)
    """
    usage = (
        "\n"
        "usage : python file.py <function> [--args=value]\n"
        "<function> := train | test | help\n"
        "example:\n"
        "python {0} train --env='env0701' --lr=0.01\n"
        "python {0} test --dataset='path/to/dataset/root/'\n"
        "python {0} help\n"
        "available args:"
    )
    print(usage.format(__file__))

    from inspect import getsource
    try:
        source = getsource(opt.__class__)
    except (OSError, TypeError):
        # Source is unavailable (e.g. interactive session or frozen build):
        # fall back to listing the public, non-callable class attributes.
        source = '\n'.join(
            '%s = %r' % (name, value)
            for name, value in vars(opt.__class__).items()
            if not name.startswith('_') and not callable(value)
        )
    print(source)
def parse_cli_args(argv):
    """Split command-line tokens into a command name and keyword arguments.

    The first token, if it does not start with ``--``, is the command
    (``train`` when absent). Each ``--key=value`` token becomes
    ``kwargs[key]``. The value is parsed as a Python literal when possible
    (``0.01`` -> float, ``'a'`` -> str) and otherwise kept as a plain string.
    A bare ``--flag`` becomes ``flag=True``.

    :param argv: list of tokens, e.g. ``sys.argv[1:]``.
    :return: ``(command, kwargs)``.
    :raises ValueError: on a positional token after the command.
    """
    import ast
    tokens = list(argv)
    command = 'train'
    if tokens and not tokens[0].startswith('--'):
        command = tokens.pop(0)
    kwargs = {}
    for token in tokens:
        if not token.startswith('--'):
            raise ValueError('unexpected argument: %r' % token)
        key, sep, raw = token[2:].partition('=')
        if not sep:
            kwargs[key] = True
            continue
        try:
            kwargs[key] = ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError):
            # Not a Python literal (e.g. env0701, a/b/path): keep the text.
            kwargs[key] = raw
    return command, kwargs


if __name__ == '__main__':
    # Implements the CLI documented by help(); with no arguments this still
    # just runs train(), as before.
    import sys
    _command, _kwargs = parse_cli_args(sys.argv[1:])
    _commands = {'train': train, 'test': test, 'help': help}
    if _command not in _commands:
        help()
        sys.exit('unknown command: %s' % _command)
    _commands[_command](**_kwargs)