import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__)))

import yaml
import logging
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.cuda import amp
from tools.datasets import create_dataloader, preprocess
from tqdm import tqdm
import math
import numpy as np
import time
import evaluate

from tools.general import (set_logging,
                           init_seeds,
                           check_dataset,
                           check_img_size,
                           torch_distributed_zero_first,
                           plot_labels,
                           labels_to_class_weights,
                           compute_loss,
                           plot_images,
                           fitness,
                           check_anchors)
from tools.torch_utils import select_device, ModelEMA

logger = logging.getLogger(__name__)


class obj(object):
    """Recursively wrap a dict (and nested dicts/lists) so its keys become attributes."""
    def __init__(self, d):
        for a, b in d.items():
            if isinstance(b, (list, tuple)):
                setattr(self, a, [obj(x) if isinstance(x, dict) else x for x in b])
            else:
                setattr(self, a, obj(b) if isinstance(b, dict) else b)


def train(modelWrapper, data, hyp, opt, device):
    model = modelWrapper.model
    ckpt = modelWrapper.config['ckpt']  # checkpoint dict, used below to resume optimizer/epoch state
    logger.info(f'Hyperparameters {hyp}')
    log_dir = './evolve'
    wdir = log_dir + '/weights'
    os.makedirs(wdir, exist_ok=True)
    last = wdir + '/last.pt'
    best = wdir + '/best.pt'
    results_file = log_dir + '/results.txt'

    epochs, batch_size, total_batch_size, weights, rank = \
        opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank

    # Save run settings
    with open(log_dir + '/hyp-train.yaml', 'w') as f:
        yaml.dump(hyp, f, sort_keys=False)
    with open(log_dir + '/opt-train.yaml', 'w') as f:
        yaml.dump(vars(opt), f, sort_keys=False)

    # Configure
    cuda = device.type != 'cpu'
    init_seeds(2 + rank)

    logger.info('Loading dataset config from %s', opt.data)
    with open(opt.data) as f:
        data_dict = yaml.load(f, Loader=yaml.FullLoader)
    with torch_distributed_zero_first(rank):
        check_dataset(data_dict)
    train_path = data_dict['train']
    test_path = data_dict['val']
    nc, names = int(data_dict['nc']), data_dict['names']
    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / total_batch_size), 1)  # accumulate loss before optimizing
    hyp['weight_decay'] *= total_batch_size * accumulate / nbs  # scale weight_decay

    pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
    for k, v in model.named_parameters():
        v.requires_grad = True
        if '.bias' in k:
            pg2.append(v)  # biases
        elif '.weight' in k and '.bn' not in k:
            pg1.append(v)  # weights with decay
        else:
            pg0.append(v)  # all else

    optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)

    optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
    optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
    logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2

    lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - hyp['lrf']) + hyp['lrf']  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    start_epoch, best_fitness = 0, 0.0
    # Optimizer
    if ckpt['optimizer'] is not None:
        optimizer.load_state_dict(ckpt['optimizer'])
        best_fitness = ckpt['best_fitness']

    # Results
    if ckpt.get('training_results') is not None:
        with open(results_file, 'w') as file:
            file.write(ckpt['training_results'])

    # Epochs
    start_epoch = ckpt['epoch'] + 1
    if epochs < start_epoch:
        logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
                    (weights, ckpt['epoch'], epochs))
        epochs += ckpt['epoch']

    del ckpt

    # Image sizes
    gs = int(max(model.stride))  # grid size (max stride)
    imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size]  # verify sizes are multiples of gs

    # DP mode
    if cuda and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # Exponential moving average
    ema = ModelEMA(model)

    # Trainloader
    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                            hyp=hyp, augment=True)
    mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
    nb = len(dataloader)  # number of batches
    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)

    ema.updates = start_epoch * nb // accumulate  # set EMA updates

    labels = np.concatenate(dataset.labels, 0)
    c = torch.tensor(labels[:, 0])  # classes
    plot_labels(labels, save_dir=log_dir)
    check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)

    # Model parameters
    hyp['cls'] *= nc / 80.  # scale COCO-tuned cls loss gain to this dataset's class count
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.gr = 1.0  # iou loss ratio (obj_loss = 1.0 or iou)
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights
    model.names = names

    # Start training
    t0 = time.time()
    nw = max(round(hyp['warmup_epochs'] * nb), 1e3)
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0)  # P, R, [email protected], [email protected], val_loss(box, obj, cls)
    scheduler.last_epoch = start_epoch - 1  # do not move
    scaler = amp.GradScaler(enabled=cuda)
    logger.info('Image sizes %g train, %g test\n'
                'Using %g dataloader workers\nLogging results to %s\n'
                'Starting training for %g epochs...' % (imgsz, imgsz_test, dataloader.num_workers, log_dir, epochs))

    for epoch in range(start_epoch, epochs):
        model.train()

        mloss = torch.zeros(4, device=device)  # mean losses
        pbar = enumerate(dataloader)
        logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'targets', 'img_size'))
        pbar = tqdm(pbar, total=nb)  # progress bar
        optimizer.zero_grad()
        for i, (imgs, targets, paths, _) in pbar:
            ni = i + nb * epoch  # number of integrated batches (since train start)
            imgs = imgs.to(device, non_blocking=True).float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                # model.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])

            # Forward
            with amp.autocast(enabled=cuda):
                pred = model(imgs)  # forward
                loss, loss_items = compute_loss(pred, targets.to(device), model)  # loss scaled by batch_size
                if rank != -1:
                    loss *= opt.world_size  # gradient averaged between devices in DDP mode

            # Backward
            scaler.scale(loss).backward()

            # Optimize
            if ni % accumulate == 0:
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()
                if ema:
                    ema.update(model)

            # Print
            mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
            mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
            s = ('%10s' * 2 + '%10.4g' * 6) % (
                '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
            pbar.set_description(s)

            # Plot
            if ni < 3:
                f = log_dir + '/train_batch%g.jpg' % ni  # filename
                result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)

            # end batch ------------------------------------------------------------------------------------------------

        # Scheduler
        lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
        scheduler.step()

        # DDP process 0 or single-GPU
        # mAP
        if ema:
            ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride'])
        final_epoch = epoch + 1 == epochs
        results, maps, times = evaluate.test(opt.data,
                                             batch_size=total_batch_size,
                                             imgsz=imgsz_test,
                                             model=ema.ema,
                                             single_cls=opt.single_cls,
                                             dataloader=dataloader,
                                             save_dir=log_dir,
                                             plots=epoch == 0 or final_epoch)
        # Write
        with open(results_file, 'a') as f:
            f.write(s + '%10.4g' * 7 % results + '\n')  # P, R, [email protected], [email protected], val_loss(box, obj, cls)

        # Update best mAP
        fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, [email protected], [email protected]]
        if fi > best_fitness:
            best_fitness = fi
            logger.info('New best fitness: %s', fi)

        # Save model
        with open(results_file, 'r') as f:  # create checkpoint
            ckpt = {'epoch': epoch,
                    'best_fitness': best_fitness,
                    'training_results': f.read(),
                    'model': ema.ema,
                    'optimizer': None if final_epoch else optimizer.state_dict()}

        # Save last, best and delete
        torch.save(ckpt, last)
        if best_fitness == fi:
            torch.save(ckpt, best)
        del ckpt
        # end epoch ----------------------------------------------------------------------------------------------------

    return imgsz
    # end training


def main(data, model, args):
    opt = obj({})
    opt.total_batch_size = 16 if not hasattr(args, 'batchSize') else args.batchSize
    opt.epochs = 300 if not hasattr(args, 'epochs') else args.epochs
    opt.batch_size = opt.total_batch_size
    opt.world_size = 1
    opt.global_rank = -1
    opt.img_size = [640, 640] if not hasattr(args, 'imgSize') else args.imgSize
    opt.hyp = os.path.join(os.path.dirname(__file__), 'config/hyp.scratch.yaml')
    opt.device = ''
    opt.weights = 'yolov5s.pt'
    opt.single_cls = False

    set_logging(opt.global_rank)

    opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # pad to 2 sizes (train, test)
    device = select_device(opt.device, batch_size=opt.batch_size)
    logger.info(opt)
    with open(opt.hyp) as f:
        hyp = yaml.load(f, Loader=yaml.FullLoader)  # load hyperparameters
    dataconfig = preprocess(data)
    model.cfg = obj({})
    model.cfg.data = opt.data = dataconfig

    imgsz = train(model, data, hyp, opt, device)
    model.cfg.imgsz = imgsz
    sys.path.pop()  # undo the sys.path.append from the top of the file
    return model